Mirror of https://github.com/jingyaogong/minimind.git, synced 2026-04-25 08:48:16 +08:00
Merge 22fa685cc4 into 349e74ec7b
This commit is contained in: commit 0305628b3d
@ -2,52 +2,124 @@ import torch
from torch import optim, nn


# Define the LoRA network structure
class LoRA(nn.Module):
    """
    LoRA network structure
    """
    def __init__(self, in_features, out_features, rank):
        super().__init__()
        # LoRA rank: controls the size of the low-rank matrices
        self.rank = rank
        # Low-rank matrices A & B
        self.A = nn.Linear(in_features, rank, bias=False)
        self.B = nn.Linear(rank, out_features, bias=False)
        # Low-rank matrix initialization
        # A: Gaussian
        # B: all zeros
        self.A.weight.data.normal_(mean=0.0, std=0.02)
        self.B.weight.data.zero_()

    def forward(self, x):
        # Forward pass
        return self.B(self.A(x))


def apply_lora(model, rank=8):
    """
    Attach LoRA layers to the given model
    """
    # Iterate over all modules of the model
    for name, module in model.named_modules():
        # Check for linear layers whose weight matrix is square
        if isinstance(module, nn.Linear) and module.weight.shape[0] == module.weight.shape[1]:
            # Create the LoRA layer and move it to the model's device
            lora = LoRA(module.weight.shape[0], module.weight.shape[1], rank=rank).to(model.device)
            # Set an attribute on the module: name "lora", value lora
            # https://docs.python.org/zh-cn/3.14/library/functions.html#setattr
            setattr(module, "lora", lora)

            # Keep a reference to the original forward function
            original_forward = module.forward

            # Define a new forward that adds the LoRA output on top of the original.
            # layer1=original_forward and layer2=lora are bound explicitly as default
            # arguments so each closure captures the values of the current loop iteration.
            def forward_with_lora(x, layer1=original_forward, layer2=lora):
                # Original forward result plus the LoRA layer's output
                return layer1(x) + layer2(x)

            # Replace the module's forward function
            module.forward = forward_with_lora


def load_lora(model, path):
    """
    Load a LoRA weight file and apply it to the model

    :param model: model containing initialized LoRA layers, i.e. only the structure
                  and empty LoRA parameters, no trained weights
    :param path: path to the LoRA weight file containing the state_dict of all A/B matrices
    """
    # Load the state_dict from the file to obtain the full set of LoRA weights
    state_dict = torch.load(path, map_location=model.device)
    # Normalize the weight keys: strip the "module." prefix from every key.
    # When DP/DDP data parallelism is used, PyTorch prepends "module." to the keys
    # of the model's weight dict; loading such a dict into an unwrapped model via
    # model.load_state_dict() would fail. LoRA fine-tuning here is single-GPU,
    # so the prefix is stripped.
    state_dict = {(k[7:] if k.startswith('module.') else k): v for k, v in state_dict.items()}
    # #### Equivalent code (the dict comprehension is preferred; this is for clarity only) ####
    # tmp_state_dict = {}
    # for k, v in state_dict.items():
    #     if k.startswith('module.'):
    #         # strip the "module." prefix
    #         k = k[7:]
    #     tmp_state_dict[k] = v
    # state_dict = tmp_state_dict
    # ########################################################################

    # Walk all modules of the model and overwrite its values with those from
    # state_dict (LoRA layers only)
    for name, module in model.named_modules():
        # Check whether the module has a lora attribute
        # https://docs.python.org/zh-cn/3.14/library/functions.html#hasattr
        if hasattr(module, 'lora'):
            # Extract the LoRA state belonging to the current module
            lora_state = {k.replace(f'{name}.lora.', ''): v for k, v in state_dict.items() if f'{name}.lora.' in k}
            # Load the LoRA layer's state dict
            module.lora.load_state_dict(lora_state)


def save_lora(model, path):
    """
    Save only the weights of the LoRA adapters, not the backbone model

    :param model: trained model object (compiled or a plain nn.Module)
    :param path: save path
    """
    # Unwrap the model if it was compiled: torch.compile stores the original
    # module in the wrapper's _orig_mod attribute, so take that if present;
    # otherwise use the model as-is.
    # https://docs.python.org/zh-cn/3.14/library/functions.html#getattr
    # getattr(object, name, default):
    #   returns the attribute if the object has it, the default value otherwise
    raw_model = getattr(model, '_orig_mod', model)

    # Empty state dict to collect the LoRA weights
    state_dict = {}
    for name, module in raw_model.named_modules():
        # Check whether the module has a lora attribute
        if hasattr(module, 'lora'):
            # Normalize the module name by stripping the "module." prefix.
            # DP/DDP data parallelism prepends "module." to weight keys, and
            # loading such keys into an unwrapped model would fail later.
            clean_name = name[7:] if name.startswith("module.") else name
            # Build this layer's lora_state with fully qualified keys
            lora_state = {f'{clean_name}.lora.{k}': v for k, v in module.lora.state_dict().items()}
            # Merge into the state dict
            state_dict.update(lora_state)

    # Save the state dict to file (LoRA weights only)
    torch.save(state_dict, path)
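# ---------------------------------------------------------------------------
# Hedged usage sketch (illustration only, not part of the original file).
# Assumes a model exposing a .device attribute, as apply_lora/load_lora
# require; the optimizer choice, learning rate and rank are arbitrary.
# ---------------------------------------------------------------------------
def _example_lora_roundtrip(model):
    """Hypothetical helper showing the intended apply -> train -> save -> load flow."""
    apply_lora(model, rank=8)
    # Freeze the backbone; only the A/B matrices stay trainable
    for param_name, param in model.named_parameters():
        param.requires_grad = 'lora' in param_name
    lora_params = [p for n, p in model.named_parameters() if 'lora' in n]
    optimizer = optim.AdamW(lora_params, lr=1e-4)  # ... run the training loop here ...
    save_lora(model, 'lora_adapter.pth')   # adapter weights only, a few MB
    load_lora(model, 'lora_adapter.pth')   # reload into a freshly apply_lora'd model
    return optimizer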
@ -1,11 +1,13 @@
#########################################
# MiniMind Config
#########################################
from transformers import PretrainedConfig


class MiniMindConfig(PretrainedConfig):
    """
    MiniMind config, inheriting from Hugging Face's PretrainedConfig.
    Holds and manages the model's hyperparameters and structural settings.
    """
    model_type = "minimind"

    def __init__(
@ -15,30 +17,30 @@ class MiniMindConfig(PretrainedConfig):
            eos_token_id: int = 2,
            hidden_act: str = 'silu',
            hidden_size: int = 512,
            intermediate_size: int = None,  # FFN intermediate size; best left unset, it is computed automatically
            max_position_embeddings: int = 32768,
            num_attention_heads: int = 8,  # number of attention heads, i.e. the number of query heads
            num_hidden_layers: int = 8,
            num_key_value_heads: int = 2,  # number of key/value heads; if unset it equals num_attention_heads
            vocab_size: int = 6400,
            rms_norm_eps: float = 1e-05,
            rope_theta: float = 1000000.0,
            inference_rope_scaling: bool = False,
            flash_attn: bool = True,
            ####################################################
            # MoE-specific configuration
            # Ignored when use_moe = False
            ####################################################
            use_moe: bool = False,
            num_experts_per_tok: int = 2,  # number of experts selected per token
            n_routed_experts: int = 4,  # total number of routed experts
            n_shared_experts: int = 1,  # shared experts
            scoring_func: str = 'softmax',  # scoring function, 'softmax' by default
            aux_loss_alpha: float = 0.01,  # alpha weight of the auxiliary loss
            seq_aux: bool = True,  # whether to compute the auxiliary loss at the sequence level
            norm_topk_prob: bool = True,  # whether to normalize the top-k probabilities, recommended
            **kwargs
    ):
        super().__init__(**kwargs)
        self.dropout = dropout
        self.bos_token_id = bos_token_id
@ -65,28 +67,29 @@ class MiniMindConfig(PretrainedConfig):
        } if self.inference_rope_scaling else None
        self.flash_attn = flash_attn
        ####################################################
        # MoE-specific configuration
        # Ignored when use_moe = False
        ####################################################
        self.use_moe = use_moe
        self.num_experts_per_tok = num_experts_per_tok  # number of experts selected per token
        self.n_routed_experts = n_routed_experts  # total number of routed experts
        self.n_shared_experts = n_shared_experts  # shared experts
        self.scoring_func = scoring_func  # scoring function, 'softmax' by default
        self.aux_loss_alpha = aux_loss_alpha  # alpha weight of the auxiliary loss
        self.seq_aux = seq_aux  # whether to compute the auxiliary loss at the sequence level
        self.norm_topk_prob = norm_topk_prob  # whether to normalize the top-k probabilities, recommended


#########################################
# MiniMind Model
#########################################
import math
import torch
import torch.nn.init as init
import torch.nn.functional as F
from torch import nn
# https://github.com/huggingface/transformers/blob/main/src/transformers/activations.py
from transformers.activations import ACT2FN
from typing import Optional, Tuple, List, Union
from transformers import PreTrainedModel, GenerationMixin, PretrainedConfig
@ -94,20 +97,29 @@ from transformers.modeling_outputs import CausalLMOutputWithPast


class RMSNorm(torch.nn.Module):
    """
    RMSNorm (Root Mean Square Normalization) layer
    """
    def __init__(self, dim: int, eps: float = 1e-5):
        super().__init__()
        # small constant for numerical stability
        self.eps = eps
        # learnable weight parameter
        self.weight = nn.Parameter(torch.ones(dim))

    def _norm(self, x):
        # RMS (root mean square) normalization
        return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps)

    def forward(self, x):
        # apply RMS normalization, then scale by the weight
        return self.weight * self._norm(x.float()).type_as(x)
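# ---------------------------------------------------------------------------
# Hedged sanity check (illustration only, not part of the original file):
# RMSNorm divides each vector by its root mean square instead of subtracting
# a mean as LayerNorm does. The demo verifies the closed form
# y = w * x / sqrt(mean(x^2) + eps) against the module.
# ---------------------------------------------------------------------------
def _demo_rmsnorm():
    x = torch.randn(2, 4, 8)  # (batch, seq_len, dim)
    norm = RMSNorm(dim=8)
    manual = norm.weight * (x / torch.sqrt(x.pow(2).mean(-1, keepdim=True) + norm.eps))
    assert torch.allclose(norm(x), manual, atol=1e-6)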


def precompute_freqs_cis(dim: int, end: int = int(32 * 1024), rope_base: float = 1e6, rope_scaling: Optional[dict] = None):
    """
    Precompute the rotary position embedding (RoPE) frequencies
    """
    freqs, attn_factor = 1.0 / (rope_base ** (torch.arange(0, dim, 2)[: (dim // 2)].float() / dim)), 1.0
    if rope_scaling is not None:
        orig_max, factor, beta_fast, beta_slow, attn_factor = (
@ -116,6 +128,7 @@ def precompute_freqs_cis(dim: int, end: int = int(32 * 1024), rope_base: float =
        )
        if end / orig_max > 1.0:
            # YaRN scaling, used to extend the effective length of the position encoding:
            # f'(i) = f(i) * ((1 - γ) + γ / s), where γ ∈ [0, 1] is a linear ramp
            inv_dim = lambda b: (dim * math.log(orig_max / (b * 2 * math.pi))) / (2 * math.log(rope_base))
            low, high = max(math.floor(inv_dim(beta_fast)), 0), min(math.ceil(inv_dim(beta_slow)), dim // 2 - 1)
            ramp = torch.clamp((torch.arange(dim // 2, device=freqs.device).float() - low) / max(high - low, 0.001), 0, 1)
@ -129,7 +142,11 @@ def precompute_freqs_cis(dim: int, end: int = int(32 * 1024), rope_base: float =


def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1):
    """
    Apply rotary position embeddings (RoPE) to the query and key tensors
    """
    def rotate_half(x):
        # move the second half of the tensor in front of the first half to realize the rotation
        return torch.cat((-x[..., x.shape[-1] // 2:], x[..., : x.shape[-1] // 2]), dim=-1)

    q_embed = (q * cos.unsqueeze(unsqueeze_dim)) + (rotate_half(q) * sin.unsqueeze(unsqueeze_dim))
@ -138,24 +155,57 @@ def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1):


def repeat_kv(x: torch.Tensor, n_rep: int) -> torch.Tensor:
    """
    Repeat the key/value heads so their count matches the number of query heads,
    which is what supports GQA (Grouped Query Attention).
    Equivalent to torch.repeat_interleave(x, dim=2, repeats=n_rep)
    https://docs.pytorch.ac.cn/docs/stable/generated/torch.repeat_interleave.html

    :param x: a batch of tensors of shape (bs, seq_len, num_key_value_heads, head_dim)
    :param n_rep: repeat count; with 32 query heads and 4 key/value heads, n_rep = 8
    """
    # unpack the shape of x into bs, seq_len, num_key_value_heads, head_dim
    bs, slen, num_key_value_heads, head_dim = x.shape
    # n_rep = 1: nothing to repeat, return the input unchanged
    if n_rep == 1:
        return x
    # n_rep != 1: repeat the kv heads the requested number of times
    else:
        return x[:, :, :, None, :].expand(bs, slen, num_key_value_heads, n_rep, head_dim).reshape(bs, slen, num_key_value_heads * n_rep, head_dim)
    # #### Equivalent code ####
    # # insert a new axis at position 3: (bs, slen, num_key_value_heads, head_dim) -> (bs, slen, num_key_value_heads, 1, head_dim)
    # x = x[:, :, :, None, :]
    # # expand the new axis to size n_rep -> (bs, slen, num_key_value_heads, n_rep, head_dim)
    # x = x.expand(bs, slen, num_key_value_heads, n_rep, head_dim)
    # # reshape to -> (bs, slen, num_key_value_heads * n_rep, head_dim)
    # # the order must be num_key_value_heads * n_rep, not the reverse:
    # # num_key_value_heads * n_rep yields [kv0, kv0, kv1, kv1]; reversed it would yield [kv0, kv1, kv0, kv1],
    # # so query0 would match kv0 while query1 in the same group matched kv1, breaking the basic premise of GQA
    # x = x.reshape(bs, slen, num_key_value_heads * n_rep, head_dim)
    # return x
    # #########################
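# ---------------------------------------------------------------------------
# Hedged shape check (illustration only, not part of the original file):
# with 2 kv heads repeated 3 times, head j of the output must equal kv head
# j // 3 of the input, i.e. the layout is [kv0, kv0, kv0, kv1, kv1, kv1].
# ---------------------------------------------------------------------------
def _demo_repeat_kv():
    x = torch.randn(1, 5, 2, 4)          # (bs, seq_len, num_key_value_heads, head_dim)
    y = repeat_kv(x, n_rep=3)
    assert y.shape == (1, 5, 6, 4)       # heads: 2 * 3 = 6
    for j in range(6):
        assert torch.equal(y[:, :, j], x[:, :, j // 3])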


class Attention(nn.Module):
    """
    Attention module with support for GQA (Grouped Query Attention)
    """
    def __init__(self, args: MiniMindConfig):
        super().__init__()
        # Number of kv heads, supporting Grouped Query Attention (GQA):
        # if args.num_key_value_heads is None, fall back to args.num_attention_heads,
        # otherwise use args.num_key_value_heads
        self.num_key_value_heads = args.num_attention_heads if args.num_key_value_heads is None else args.num_key_value_heads
        # args.num_attention_heads must be divisible by self.num_key_value_heads
        assert args.num_attention_heads % self.num_key_value_heads == 0
        # number of query heads
        self.n_local_heads = args.num_attention_heads
        # number of key/value heads
        self.n_local_kv_heads = self.num_key_value_heads
        # repetition factor
        self.n_rep = self.n_local_heads // self.n_local_kv_heads
        # dimension per head
        self.head_dim = args.hidden_size // args.num_attention_heads
        # projection layers: map the hidden state to query, key and value
        self.q_proj = nn.Linear(args.hidden_size, args.num_attention_heads * self.head_dim, bias=False)
        self.k_proj = nn.Linear(args.hidden_size, self.num_key_value_heads * self.head_dim, bias=False)
        self.v_proj = nn.Linear(args.hidden_size, self.num_key_value_heads * self.head_dim, bias=False)
@ -163,137 +213,294 @@ class Attention(nn.Module):
        self.attn_dropout = nn.Dropout(args.dropout)
        self.resid_dropout = nn.Dropout(args.dropout)
        self.dropout = args.dropout
        # check whether Flash Attention is available
        self.flash = hasattr(torch.nn.functional, 'scaled_dot_product_attention') and args.flash_attn
        # print("WARNING: using slow attention. Flash Attention requires PyTorch >= 2.0")

    def forward(
            self,
            x: torch.Tensor,
            position_embeddings: Tuple[torch.Tensor, torch.Tensor],  # receives cos and sin -> cos, sin = position_embeddings
            past_key_value: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,  # variable holding the cached kv tensors
            use_cache=False,  # whether to enable the KV cache
            attention_mask: Optional[torch.Tensor] = None  # attention mask of shape (batch_size, seq_len)
    ):
        # input dimensions
        bsz, seq_len, _ = x.shape
        # project into query, key and value spaces
        xq, xk, xv = self.q_proj(x), self.k_proj(x), self.v_proj(x)
        # reshape into multi-head format
        xq = xq.view(bsz, seq_len, self.n_local_heads, self.head_dim)
        xk = xk.view(bsz, seq_len, self.n_local_kv_heads, self.head_dim)
        xv = xv.view(bsz, seq_len, self.n_local_kv_heads, self.head_dim)
        # positional encodings
        cos, sin = position_embeddings
        # apply the rotary position embeddings (to q and k only; v does not need them)
        xq, xk = apply_rotary_pos_emb(xq, xk, cos, sin)

        # KV cache: concatenate the past key/value tensors
        if past_key_value is not None:
            xk = torch.cat([past_key_value[0], xk], dim=1)
            xv = torch.cat([past_key_value[1], xv], dim=1)
        past_kv = (xk, xv) if use_cache else None

        # transpose into attention layout: (batch, heads, seq_len, head_dim)
        xq, xk, xv = (
            xq.transpose(1, 2),
            repeat_kv(xk, self.n_rep).transpose(1, 2),
            repeat_kv(xv, self.n_rep).transpose(1, 2)
        )

        # use Flash Attention when available and applicable
        if self.flash and (seq_len > 1) and (past_key_value is None) and (attention_mask is None or torch.all(attention_mask == 1)):
            # https://docs.pytorch.ac.cn/docs/stable/generated/torch.nn.functional.scaled_dot_product_attention.html
            output = F.scaled_dot_product_attention(xq, xk, xv, dropout_p=self.dropout if self.training else 0.0, is_causal=True)
        # otherwise fall back to standard attention
        else:
            # attention scores
            scores = (xq @ xk.transpose(-2, -1)) / math.sqrt(self.head_dim)
            # causal mask: prevents attending to future positions
            # https://docs.pytorch.ac.cn/docs/stable/generated/torch.triu.html
            # torch.triu(input, diagonal)
            # returns the upper-triangular part of a matrix (2D tensor) or batch of
            # matrices input; the other elements of the result tensor are set to 0.
            # diagonal controls which diagonal to keep:
            #   diagonal = 0 keeps the main diagonal and everything above it;
            #   positive values exclude that many diagonals above the main one;
            #   negative values include that many diagonals below it
            scores[:, :, :, -seq_len:] += torch.triu(torch.full((seq_len, seq_len), float("-inf"), device=scores.device), diagonal=1)

            # attention_mask shape: (batch_size, seq_len)
            # - 1 marks a valid token
            # - 0 marks padding
            if attention_mask is not None:
                # (bs, seq_len) -> (bs, 1, 1, seq_len)
                # so the mask can broadcast over the attention scores,
                # since scores have shape (bs, num_heads, seq_len, seq_len)
                extended_attention_mask = attention_mask.unsqueeze(1).unsqueeze(2)
                # -1e9 is a very large negative number (-1000000000.0) used to mask invalid positions
                extended_attention_mask = (1.0 - extended_attention_mask) * -1e9
                # after applying the mask:
                # - mask value 1 -> additive term 0 -> scores unchanged (kept)
                # - mask value 0 -> additive term near -inf -> scores pulled far down, softmax ≈ 0
                scores = scores + extended_attention_mask

            # softmax normalization
            scores = F.softmax(scores.float(), dim=-1).type_as(xq)
            # dropout
            scores = self.attn_dropout(scores)
            # attention output (weighted sum over the values)
            output = scores @ xv

        # reshape and project out
        output = output.transpose(1, 2).reshape(bsz, seq_len, -1)
        output = self.resid_dropout(self.o_proj(output))
        return output, past_kv
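# ---------------------------------------------------------------------------
# Hedged illustration of the causal mask built in the slow path above (not
# part of the original file): for seq_len = 4, torch.triu(..., diagonal=1)
# places -inf strictly above the main diagonal, so position i can only attend
# to positions <= i once the mask is added to the scores.
# ---------------------------------------------------------------------------
def _demo_causal_mask():
    seq_len = 4
    mask = torch.triu(torch.full((seq_len, seq_len), float("-inf")), diagonal=1)
    # row 0 -> [0, -inf, -inf, -inf]
    # row 3 -> [0,    0,    0,    0]
    assert mask[0, 1] == float("-inf") and mask[3, 3] == 0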


class FeedForward(nn.Module):
    """
    Feed-forward network
    """
    def __init__(self, config: MiniMindConfig):
        super().__init__()
        # Compute the FFN intermediate size and round it up to a multiple of 64;
        # the 64-alignment optimizes GPU memory access and speeds up inference
        if config.intermediate_size is None:
            # a SwiGLU FFN's intermediate size is conventionally hidden_size * 8/3 (LLaMA style)
            intermediate_size = int(config.hidden_size * 8 / 3)
            # round up to the nearest multiple of 64
            # - formula: (x + n - 1) // n * n
            # - add 63 -> floor-divide (//) by 64 -> multiply by 64
            config.intermediate_size = 64 * ((intermediate_size + 64 - 1) // 64)
        # the three linear layers of the SwiGLU FFN
        self.gate_proj = nn.Linear(config.hidden_size, config.intermediate_size, bias=False)  # gate projection
        self.down_proj = nn.Linear(config.intermediate_size, config.hidden_size, bias=False)  # down projection
        self.up_proj = nn.Linear(config.hidden_size, config.intermediate_size, bias=False)  # up projection
        self.dropout = nn.Dropout(config.dropout)
        # activation function, SiLU by default
        # https://github.com/huggingface/transformers/blob/main/src/transformers/activations.py
        # from transformers.activations import ACT2FN
        self.act_fn = ACT2FN[config.hidden_act]

    def forward(self, x):
        # SwiGLU: gate * up -> down
        return self.dropout(self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x)))
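# ---------------------------------------------------------------------------
# Hedged worked example of the 64-alignment above (not part of the original
# file): with hidden_size = 512, 512 * 8/3 = 1365.33 truncates to 1365, and
# 64 * ((1365 + 63) // 64) = 64 * 22 = 1408, the nearest multiple of 64 at or
# above 1365.
# ---------------------------------------------------------------------------
def _demo_intermediate_size(hidden_size=512):
    intermediate_size = int(hidden_size * 8 / 3)          # 1365
    aligned = 64 * ((intermediate_size + 64 - 1) // 64)   # 1408
    assert aligned % 64 == 0 and aligned >= intermediate_size
    return aligned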


class MoEGate(nn.Module):
    """
    MoE gating network

    - Dynamically selects the top-k most suitable experts for each token from the
      input hidden states and outputs their weights
    - Optionally computes a load-balancing auxiliary loss to keep experts from
      monopolizing the routing

    Input:
    - hidden_states: (batch_size, seq_len, hidden_size)
    Output:
    - topk_idx: top-k expert indices chosen per token, shape (bsz * seq_len, k)
    - topk_weight: the corresponding (normalized) expert weights, same shape
    - aux_loss: auxiliary loss term (balances expert load during training)
    """
    def __init__(self, config: MiniMindConfig):
        super().__init__()
        self.config = config
        self.top_k = config.num_experts_per_tok  # number of experts selected per token
        self.n_routed_experts = config.n_routed_experts  # total number of routed experts

        self.scoring_func = config.scoring_func  # scoring function; only softmax is supported
        self.alpha = config.aux_loss_alpha  # auxiliary-loss weight, commonly 0.01 ~ 0.1
        self.seq_aux = config.seq_aux  # whether to use the sequence-level auxiliary loss (False is more common)

        self.norm_topk_prob = config.norm_topk_prob  # whether to normalize the top-k probabilities, recommended
        self.gating_dim = config.hidden_size  # gating dimension; note it equals the hidden size
        self.weight = nn.Parameter(torch.empty((self.n_routed_experts, self.gating_dim)))  # gating weight
        self.reset_parameters()  # Kaiming-initialize self.weight

    def reset_parameters(self) -> None:
        # Kaiming initialization of the weight
        init.kaiming_uniform_(self.weight, a=math.sqrt(5))

    def forward(self, hidden_states):
        """
        :param hidden_states: (bsz, seq_len, hidden_size)
        """
        # 1. Record the original shape and flatten, treating every token as an independent sample
        bsz, seq_len, h = hidden_states.shape
        # shape: (bsz * seq_len, hidden_size)
        hidden_states = hidden_states.view(-1, h)

        # 2. Compute the gating scores
        # logits: (bsz * seq_len, n_routed_experts)
        # this step measures the similarity between each token and each expert
        # https://docs.pytorch.ac.cn/docs/stable/generated/torch.nn.functional.linear.html
        # - hidden_states: (bsz * seq_len, hidden_size)
        # - self.weight: (n_routed_experts, hidden_size) -> self.gating_dim = config.hidden_size
        # - logits: hidden_states @ weight.T = (bsz * seq_len, n_routed_experts)
        logits = F.linear(hidden_states, self.weight, None)

        if self.scoring_func == 'softmax':
            # softmax turns the logits into a probability distribution summing to 1
            # scores: (bsz * seq_len, n_routed_experts)
            scores = logits.softmax(dim=-1)
        else:
            raise NotImplementedError(f'unsupported scoring function for MoE gating: {self.scoring_func}')

        # 3. Select the top-k experts
        # https://docs.pytorch.ac.cn/docs/stable/generated/torch.topk.html
        # returns the k largest elements of scores along dim=-1 -> (values, indices)
        # - topk_weight: the chosen experts' probabilities (bsz * seq_len, top_k)
        # - topk_idx: the chosen experts' indices (bsz * seq_len, top_k)
        topk_weight, topk_idx = torch.topk(scores, k=self.top_k, dim=-1, sorted=False)

        # 4. Normalize the top-k probabilities
        # when more than one expert is selected, renormalize their probabilities to sum to 1
        if self.top_k > 1 and self.norm_topk_prob:
            # denominator
            denominator = topk_weight.sum(dim=-1, keepdim=True) + 1e-20
            # renormalized probabilities
            topk_weight = topk_weight / denominator

        # 5. Compute the auxiliary loss (for load balancing)
        if self.training and self.alpha > 0.0:  # self.training is inherited from torch.nn.Module
            # (bsz * seq_len, n_routed_experts)
            scores_for_aux = scores
            aux_topk = self.top_k
            # (bsz * seq_len, top_k) -> (bsz, seq_len * top_k)
            topk_idx_for_aux_loss = topk_idx.view(bsz, -1)

            # Branch 1: sequence-level auxiliary loss (less common)
            if self.seq_aux:
                # (bsz, seq_len, n_routed_experts)
                scores_for_seq_aux = scores_for_aux.view(bsz, seq_len, -1)
                # ce stands for Counter for Experts / Cumulative Expert usage count
                # ce shape: (bsz, n_routed_experts)
                ce = torch.zeros(bsz, self.n_routed_experts, device=hidden_states.device)

                # The target tensor is ce; the goal is to count how often each expert is chosen per sequence
                # https://docs.pytorch.ac.cn/docs/stable/generated/torch.Tensor.scatter_add_.html
                # - Tensor.scatter_add_(dim, index, src)
                # - adds src[i] into the target tensor at the position along dim given by index[i]
                # - self, index and src must have the same number of dimensions

                # div_(seq_len * aux_topk / self.n_routed_experts) is the normalization factor:
                # under perfect balance each expert would be chosen (seq_len * top_k) / n_routed_experts times,
                # so after ce.div_(...), ce[i][j] becomes a frequency normalized by that expectation
                # https://docs.pytorch.ac.cn/docs/stable/generated/torch.div.html
                # - torch.div(input, other, *, rounding_mode=None, out=None)
                # - divides each element of input by the corresponding element of other
                ce.scatter_add_(
                    # accumulate along the expert dimension
                    dim=1,
                    # each token's top-k expert indices
                    index=topk_idx_for_aux_loss,
                    # add 1 for every selection event
                    src=torch.ones(bsz, seq_len * aux_topk, device=hidden_states.device)).div_(seq_len * aux_topk / self.n_routed_experts)

                # Compute the auxiliary loss.
                # The goal is to minimize load imbalance, i.e. drive f_i ≈ 1/n
                # formula: aux_loss = alpha * (P_i * f_i).sum()
                # - P_i: expert i's average predicted probability, from scores_for_seq_aux.mean(dim=1)
                # - f_i: expert i's actual (normalized) selection frequency, from ce
                aux_loss = (ce * scores_for_seq_aux.mean(dim=1)).sum(dim=1).mean() * self.alpha
                ######################################################
                # scores_for_seq_aux.mean(dim=1)
                #   -> per sample, the average predicted probability of each expert (P_i)
                # ce * scores_for_seq_aux.mean(dim=1)
                #   -> per sample, frequency × predicted probability for each expert (P_i * f_i)
                # .sum(dim=1).mean()
                #   -> total imbalance per sample, averaged over the batch
                # self.alpha
                #   -> scaling, keeps the term from dominating the main loss
                ######################################################

            # Branch 2: token-level / batch-level auxiliary loss (default)
            else:
                # 1. Actual selection frequency f_i of each expert
                # - mask_ce shape: (bsz * seq_len * top_k, n_routed_experts)
                # https://docs.pytorch.ac.cn/docs/stable/generated/torch.nn.functional.one_hot.html
                # takes a LongTensor of index values with shape (*) and returns a tensor of shape (*, num_classes)
                # - each row is one selection event
                # - each column is one expert (num_classes=self.n_routed_experts)
                mask_ce = F.one_hot(topk_idx_for_aux_loss.view(-1), num_classes=self.n_routed_experts)
                # fraction of selections in the batch that went to each expert
                # ce shape: (n_routed_experts,)
                ce = mask_ce.float().mean(0)
                # fi: normalized frequency; ideally fi stays close to 1
                fi = ce * self.n_routed_experts

                # 2. Average routing probability P_i of each expert
                # Pi shape: (n_routed_experts,)
                Pi = scores_for_aux.mean(0)

                # 3. Auxiliary loss
                # dot product and sum: aux_loss = alpha * (P_i * f_i).sum()
                # the target distribution is uniform, so minimizing this loss generally means balance
                aux_loss = (Pi * fi).sum() * self.alpha

        # no auxiliary loss needed
        else:
            aux_loss = scores.new_zeros(1).squeeze()

        # Return values
        # - topk_idx: selected expert indices: (bsz * seq_len, top_k)
        # - topk_weight: selected expert weights: (bsz * seq_len, top_k)
        # - aux_loss: auxiliary loss, used for load balancing
        return topk_idx, topk_weight, aux_loss
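# ---------------------------------------------------------------------------
# Hedged numeric sketch of the default (batch-level) auxiliary loss above
# (not part of the original file). The gate concentrates probability on
# expert 0; routing every selection there costs 2.8x more than spreading the
# selections evenly, which is the pressure toward f_i ≈ 1 for every expert.
# ---------------------------------------------------------------------------
def _demo_batch_aux_loss():
    n_experts, alpha = 4, 0.01
    Pi = torch.tensor([0.7, 0.1, 0.1, 0.1])        # average routing probabilities P_i
    skewed = torch.tensor([0, 0, 0, 0])            # every top-k selection hits expert 0
    fi = F.one_hot(skewed, num_classes=n_experts).float().mean(0) * n_experts   # [4, 0, 0, 0]
    assert torch.isclose((Pi * fi).sum() * alpha, torch.tensor(0.028))          # imbalanced: 2.8 * alpha
    balanced = torch.tensor([0, 1, 2, 3])          # one selection per expert
    fi = F.one_hot(balanced, num_classes=n_experts).float().mean(0) * n_experts  # [1, 1, 1, 1]
    assert torch.isclose((Pi * fi).sum() * alpha, torch.tensor(0.010))          # balanced: 1.0 * alpha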


class MOEFeedForward(nn.Module):
    """
    MoE feed-forward network
    """
    def __init__(self, config: MiniMindConfig):
        super().__init__()
        self.config = config
        # routed expert list; only the experts picked by the gate are evaluated
        self.experts = nn.ModuleList([
            FeedForward(config)
            for _ in range(config.n_routed_experts)
        ])
        # gating network: input is x, output is the top-k expert indices and
        # weights (plus the auxiliary loss aux_loss)
        self.gate = MoEGate(config)
        # shared expert list (optional)
        if config.n_shared_experts > 0:
            self.shared_experts = nn.ModuleList([
                FeedForward(config)
@ -301,147 +508,351 @@ class MOEFeedForward(nn.Module):
            ])

    def forward(self, x):
        """
        :param x: (bsz, seq_len, hidden_size)
        """
        # original input, kept for the shared-expert addition later
        identity = x
        orig_shape = x.shape
        bsz, seq_len, _ = x.shape

        # Route through the gate
        # - topk_idx: selected expert indices: (bsz * seq_len, top_k)
        # - topk_weight: selected expert probabilities: (bsz * seq_len, top_k)
        # - aux_loss: auxiliary loss, used for load balancing
        topk_idx, topk_weight, aux_loss = self.gate(x)

        # Prepare the data
        x = x.view(-1, x.shape[-1])  # (bsz * seq_len, hidden_size)
        flat_topk_idx = topk_idx.view(-1)  # (bsz * seq_len * top_k)

        # Training mode: process the experts in parallel
        if self.training:
            # Repeat x so each token appears once per chosen expert:
            # with k=2, every token is processed twice
            # https://docs.pytorch.ac.cn/docs/stable/generated/torch.repeat_interleave.html
            x = x.repeat_interleave(self.config.num_experts_per_tok, dim=0)  # (bsz * seq_len * top_k, hidden_size)

            # https://docs.pytorch.ac.cn/docs/stable/generated/torch.empty_like.html
            # pre-allocate the output (an uninitialized tensor the same size as x)
            y = torch.empty_like(x, dtype=x.dtype)  # (bsz * seq_len * top_k, hidden_size)

            # iterate over every expert
            for i, expert in enumerate(self.experts):
                # mask: all token slots assigned to expert i
                mask = (flat_topk_idx == i)
                # x[mask]: only the inputs belonging to this expert
                expert_out = expert(x[mask])
                # when the expert received tokens, write its output into y
                if expert_out.shape[0] > 0:
                    y[mask] = expert_out.to(y.dtype)
                # otherwise add a zero-gradient term so the parameters still receive gradients
                else:
                    # the 0 * sum(...) keeps this expert's parameters in the computation graph
                    # and avoids DDP "unused parameters" errors
                    y[mask] = expert_out.to(y.dtype) + 0 * sum(p.sum() for p in expert.parameters())

            # Aggregate the results
            # y currently has shape (bsz * seq_len * top_k, hidden_size), the "unrolled" expert outputs,
            # which must be merged with their weights into -> (bsz * seq_len, hidden_size)
            # - topk_weight: (bsz * seq_len, top_k)
            # - y.view(*topk_weight.shape, -1): (bsz * seq_len, top_k, hidden_size)
            # - topk_weight.unsqueeze(-1): (bsz * seq_len, top_k, 1)
            # - broadcast multiply: (N, k, h) * (N, k, 1) = (N, k, h)
            # - finally sum over the top_k dimension (dim=1): (bsz * seq_len, top_k, hidden_size) -> (bsz * seq_len, hidden_size)
            y = (y.view(*topk_weight.shape, -1) * topk_weight.unsqueeze(-1)).sum(dim=1)
            # reshape back from (bsz * seq_len, hidden_size) to the original (bsz, seq_len, hidden_size)
            y = y.view(*orig_shape)

        # Inference mode: optimized expert dispatch
        # - serialized processing that saves memory and avoids the many mask operations
        else:
            # flat_topk_idx: (bsz * seq_len * top_k)
            # topk_weight.view(-1, 1) flattens all dimensions: (bsz * seq_len * top_k, 1)
            y = self.moe_infer(x, flat_topk_idx, topk_weight.view(-1, 1)).view(*orig_shape)

        # add the shared experts' output
        if self.config.n_shared_experts > 0:
            for expert in self.shared_experts:
                y = y + expert(identity)
        self.aux_loss = aux_loss

        # (bsz, seq_len, hidden_size)
        return y

    @torch.no_grad()
    def moe_infer(self, x, flat_expert_indices, flat_expert_weights):
        """
        Efficient MoE implementation for inference.
        Instead of copying x with repeat_interleave (which saves memory), it reorders by index.

        :param x: (bsz * seq_len, hidden_size)
        :param flat_expert_indices: (bsz * seq_len * top_k) expert ids
        :param flat_expert_weights: (bsz * seq_len * top_k, 1) matching weights
        """
        # result buffer (bsz * seq_len, hidden_size)
        expert_cache = torch.zeros_like(x)

        # 1. Sort the token-to-expert assignments.
        # idxs is the sorting permutation that groups "expert 0's jobs", "expert 1's jobs", ... together
        idxs = flat_expert_indices.argsort()

        # 2. Count how many tokens each expert processes
        # - bincount counts how often each expert appears
        # https://docs.pytorch.ac.cn/docs/stable/generated/torch.bincount.html
        # - cumsum takes the running sum along dim, used for slicing below
        # https://docs.pytorch.ac.cn/docs/stable/generated/torch.cumsum.html
        tokens_per_expert = flat_expert_indices.bincount().cpu().numpy().cumsum(0)

        # 3. Map sorted positions back to original token ids
        # - flat_expert_indices has length N*k, the unrolled (bsz * seq_len * top_k)
        # - // top_k maps the "unrolled index" back to the original token row
        token_idxs = idxs // self.config.num_experts_per_tok
        #####################################
        # When tokens_per_expert = [6, 15, 20, 26], tokens_per_expert.shape[0] is the expert count (4 here),
        # and with token_idxs = [3, 7, 19, 21, 24, 25, 4, 5, 6, 10, 11, 12...]:
        # token_idxs[:6] -> [3, 7, 19, 21, 24, 25] are the 6 tokens handled by expert 0
        # (a token may be handled by several experts, depending on num_experts_per_tok);
        # the next 9 positions token_idxs[6:15] -> [4, 5, 6, 10, 11, 12...] belong to expert 1, and so on
        #####################################

        # iterate over the experts (sliced out via tokens_per_expert)
        for i, end_idx in enumerate(tokens_per_expert):
            start_idx = 0 if i == 0 else tokens_per_expert[i - 1]
            # this expert has no work in this batch
            if start_idx == end_idx:
                continue
            expert = self.experts[i]

            # original row numbers of all tokens belonging to expert i
            exp_token_idx = token_idxs[start_idx:end_idx]
            # fetch those tokens from x
            expert_tokens = x[exp_token_idx]
            # forward pass
            expert_out = expert(expert_tokens).to(expert_cache.dtype)
            # multiply by the matching weights
            # - idxs[start_idx:end_idx] selects the matching rows of flat_expert_weights
            expert_out.mul_(flat_expert_weights[idxs[start_idx:end_idx]])

            # Accumulate the results back into expert_cache.
            # scatter_add_ handles the case where a token selected both expert A and expert B:
            # its result must be (OutA * wA) + (OutB * wB), so scatter_add adds each
            # contribution back at the token's index
            expert_cache.scatter_add_(0, exp_token_idx.view(-1, 1).repeat(1, x.shape[-1]), expert_out)
            #####################################
            # https://docs.pytorch.ac.cn/docs/stable/generated/torch.Tensor.scatter_add_.html
            # - Tensor.scatter_add_(dim, index, src)
            # - adds src[i] into the target tensor at the position along dim given by index[i]
            # - self, index and src must have the same number of dimensions
            #####################################

        # (bsz * seq_len, hidden_size)
        return expert_cache
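# ---------------------------------------------------------------------------
# Hedged toy walk-through of the moe_infer index bookkeeping above (not part
# of the original file): 3 tokens, top_k = 2, 3 experts. argsort groups the
# assignments per expert, the cumsum of bincount gives the slice boundaries,
# and // top_k recovers the owning token of each assignment.
# ---------------------------------------------------------------------------
def _demo_moe_infer_indexing():
    top_k = 2
    # token 0 -> experts (0, 2), token 1 -> (1, 2), token 2 -> (0, 1)
    flat_expert_indices = torch.tensor([0, 2, 1, 2, 0, 1])
    idxs = flat_expert_indices.argsort()                       # e.g. [0, 4, 2, 5, 1, 3]
    boundaries = flat_expert_indices.bincount().cumsum(0)      # [2, 4, 6]
    token_idxs = idxs // top_k                                 # owning token per assignment
    assert boundaries.tolist() == [2, 4, 6]
    assert sorted(token_idxs[:2].tolist()) == [0, 2]           # expert 0 serves tokens 0 and 2
    assert sorted(token_idxs[2:4].tolist()) == [1, 2]          # expert 1 serves tokens 1 and 2
    assert sorted(token_idxs[4:6].tolist()) == [0, 1]          # expert 2 serves tokens 0 and 1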


class MiniMindBlock(nn.Module):
    """
    A single Transformer block of the MiniMind model
    """
    def __init__(self, layer_id: int, config: MiniMindConfig):
        super().__init__()
        self.num_attention_heads = config.num_attention_heads
        self.hidden_size = config.hidden_size
        self.head_dim = config.hidden_size // config.num_attention_heads
        # self-attention layer (Attention module)
        self.self_attn = Attention(config)

        self.layer_id = layer_id
        # pre-attention normalization
        self.input_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
        # post-attention normalization
        self.post_attention_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
        # choose the FFN type from the config: plain FFN / MoE FFN
        self.mlp = FeedForward(config) if not config.use_moe else MOEFeedForward(config)

    def forward(self, hidden_states, position_embeddings, past_key_value=None, use_cache=False, attention_mask=None):
        """
        :param hidden_states: (batch_size, seq_len, hidden_size)
        :param position_embeddings: Tuple[torch.Tensor, torch.Tensor]
        :param past_key_value: Tuple[torch.Tensor, torch.Tensor]
        :param use_cache: whether to enable the KV cache
        :param attention_mask: attention mask of shape (batch_size, seq_len)
        """
        # first residual connection: input norm -> attention -> residual add
        residual = hidden_states
        hidden_states, present_key_value = self.self_attn(
            self.input_layernorm(hidden_states),
            position_embeddings,
            past_key_value,
            use_cache,
            attention_mask
        )
        hidden_states += residual

        # second residual connection: post-attention norm -> FFN -> residual add
        hidden_states = hidden_states + self.mlp(self.post_attention_layernorm(hidden_states))

        # return the output and the cached key/value pair
        return hidden_states, present_key_value


class MiniMindModel(nn.Module):
    """
    Main MiniMind model class
    """
    def __init__(self, config: MiniMindConfig):
        super().__init__()
        self.config = config
        self.vocab_size = config.vocab_size  # vocabulary size
        self.num_hidden_layers = config.num_hidden_layers  # number of hidden layers
        self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size)  # token embedding layer
        self.dropout = nn.Dropout(config.dropout)
        # stack of identical Transformer blocks (MiniMindBlock);
        # each block holds self-attention + FFN, with parameters managed via ModuleList
        self.layers = nn.ModuleList([MiniMindBlock(l, config) for l in range(self.num_hidden_layers)])
        self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)  # final normalization layer

        # precompute the rotary position embeddings
        freqs_cos, freqs_sin = precompute_freqs_cis(
            dim=config.hidden_size // config.num_attention_heads,
            end=config.max_position_embeddings, rope_base=config.rope_theta,
            rope_scaling=config.rope_scaling
        )
        # register as buffers; they are not optimized
        self.register_buffer("freqs_cos", freqs_cos, persistent=False)
        self.register_buffer("freqs_sin", freqs_sin, persistent=False)

    def forward(
            self,
            # input token sequence, shape (batch_size, seq_len)
            input_ids: Optional[torch.Tensor] = None,
            # attention mask: 1 marks a valid token, 0 marks padding
            attention_mask: Optional[torch.Tensor] = None,
            # KV cache list (for generation)
            # - each element is a (key, value) pair
            # - shapes: [(bs, seq_len_k, num_heads, head_dim), ...]
            past_key_values: Optional[List[Tuple[torch.Tensor, torch.Tensor]]] = None,
            # whether to enable the cache (speeds up inference)
            use_cache: bool = False,
            **kwargs
    ):
        # input dimensions
        batch_size, seq_length = input_ids.shape

        # KV cache handling
        #######################################################################################
        # https://docs.python.org/zh-cn/3.14/library/functions.html#hasattr
        # - returns True if the string names one of the object's attributes, False otherwise
        if hasattr(past_key_values, 'layers'):
            # if the incoming past_key_values is an object with a .layers attribute
            # (e.g. a HuggingFace cache wrapper), treat it as an invalid cache and reset it to None
            past_key_values = None

        # Default-fill past_key_values:
        # if it is None, an empty list [] or otherwise falsy,
        # replace it with a list of num_hidden_layers Nones -> [None] * len(self.layers)
        past_key_values = past_key_values or [None] * len(self.layers)
        #######################################################################################

        # starting position (for the KV cache)
        start_pos = past_key_values[0][0].shape[1] if past_key_values[0] is not None else 0

        # token embedding + dropout
        hidden_states = self.dropout(self.embed_tokens(input_ids))

        # positional encodings
        position_embeddings = (
            self.freqs_cos[start_pos:start_pos + seq_length],
            self.freqs_sin[start_pos:start_pos + seq_length]
        )

        # layer-by-layer forward pass
        presents = []
        # stacked Transformer blocks
        for layer_idx, (layer, past_key_value) in enumerate(zip(self.layers, past_key_values)):
            # each MiniMindBlock takes the arguments below and returns the next
            # hidden_states plus the KV cache pair present
            hidden_states, present = layer(
                hidden_states,  # current hidden states
                position_embeddings,  # matching RoPE positional encodings (cos, sin)
                past_key_value=past_key_value,  # previously cached (key, value) tuple
                use_cache=use_cache,  # whether to enable the cache
                attention_mask=attention_mask  # attention mask
            )
            # store each layer's kv cache
            presents.append(present)

        # final normalization
        hidden_states = self.norm(hidden_states)

        # MoE auxiliary loss:
        # sum the .aux_loss attribute over all layers that use MOEFeedForward
        # (the attribute lives on MOEFeedForward; its value comes from MoEGate)
        # https://docs.pytorch.ac.cn/docs/stable/generated/torch.Tensor.new_zeros.html
        # - Tensor.new_zeros(size, *, dtype=None, device=None, requires_grad=False, layout=torch.strided, pin_memory=False)
        # - returns a Tensor of size `size` filled with 0;
        # - by default the result has the same torch.dtype and torch.device as this Tensor
        aux_loss = sum([l.mlp.aux_loss for l in self.layers if isinstance(l.mlp, MOEFeedForward)], hidden_states.new_zeros(1).squeeze())
        ######################################################################################
        # Why sum(..., start) instead of plain sum(...)?
        # - with no MoE layers the list is empty -> sum([]) returns int(0)
        # - the rest of the code treats aux_loss as a torch.Tensor (e.g. calling .item() on it),
        #   which a plain int would break
        # - so a zero scalar with the same device/dtype as hidden_states is passed explicitly
        ######################################################################################
        # hidden_states.new_zeros(1).squeeze()
        # - hidden_states.new_zeros(1) returns a one-element tensor [0.]
        # - squeeze() drops the size-1 dimension, leaving the scalar tensor(0.)
        #
        ##### Equivalent code ################################################################
        # - zero_scalar = hidden_states.new_zeros(1).squeeze()  # a scalar tensor(0.)
        # - aux_loss = zero_scalar                              # aux_loss starts at tensor(0.)
        # - for layer in self.layers:
        # -     if isinstance(layer.mlp, MOEFeedForward):       # check for the MoE type
        # -         aux_loss = aux_loss + layer.mlp.aux_loss    # accumulate the auxiliary loss
        ######################################################################################

        # return the final hidden states, the cache and the MoE auxiliary loss
        return hidden_states, presents, aux_loss
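# ---------------------------------------------------------------------------
# Hedged illustration of the sum(..., start) idiom above (not part of the
# original file): with no MoE layers the list is empty, and only an explicit
# tensor start value keeps aux_loss a Tensor rather than the int 0.
# ---------------------------------------------------------------------------
def _demo_sum_with_start():
    h = torch.zeros(2, 3)
    assert isinstance(sum([], h.new_zeros(1).squeeze()), torch.Tensor)  # tensor start -> Tensor
    assert isinstance(sum([]), int)                                     # default start 0 -> int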


class MiniMindForCausalLM(PreTrainedModel, GenerationMixin):
    """
    MiniMindForCausalLM: the causal language model wrapper
    - Combines the MiniMindModel backbone with a linear output head (lm_head) for full text generation
    - Compatible with the HuggingFace Transformers training/inference interfaces (e.g. model.generate())
    - Supports autoregressive decoding, the KV cache, loss computation and the other standard features

    # Reference
    https://huggingface.co/docs/transformers/main_classes/text_generation
    """
    config_class = MiniMindConfig

    def __init__(self, config: MiniMindConfig = None):
        # initialize the config; fall back to default parameters when none is given
        self.config = config or MiniMindConfig()
        super().__init__(self.config)
        # the MiniMind backbone
        self.model = MiniMindModel(self.config)
        # language-model head (output head)
        self.lm_head = nn.Linear(self.config.hidden_size, self.config.vocab_size, bias=False)
        # Weight tying: the token embedding and the LM head share their weights
        # - cuts the parameter count by about 15~20% here (roughly vocab_size * hidden_size)
        # - improves generalization by keeping token vectors and the output distribution consistent
        #   (standard practice in the LLaMA and GPT families)
        # - must run after super().__init__(), since the parent class may initialize parameters
        self.model.embed_tokens.weight = self.lm_head.weight

    def forward(
            self,
            input_ids: Optional[torch.Tensor] = None,
            attention_mask: Optional[torch.Tensor] = None,
            labels: Optional[torch.Tensor] = None,
            past_key_values: Optional[List[Tuple[torch.Tensor, torch.Tensor]]] = None,
            use_cache: bool = False,
            logits_to_keep: Union[int, torch.Tensor] = 0,
            **args
    ):
        # forward pass through the backbone
        hidden_states, past_key_values, aux_loss = self.model(
            input_ids=input_ids,
            attention_mask=attention_mask,
@ -449,15 +860,29 @@ class MiniMindForCausalLM(PreTrainedModel, GenerationMixin):
            use_cache=use_cache,
            **args
        )
        # Compute the logits (keep only the last logits_to_keep positions to save memory)
        # - useful for long-context inference where only the tail of the sequence matters
        # - logits_to_keep = 0 keeps the whole sequence
        # - an integer N keeps the logits of the last N tokens
        slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep
        logits = self.lm_head(hidden_states[:, slice_indices, :])

        loss = None
        if labels is not None:
            # label alignment: the standard causal-language-modeling loss setup
            shift_logits = logits[..., :-1, :].contiguous()  # drop the prediction at the last position
            shift_labels = labels[..., 1:].contiguous()  # drop the label at the first position
            # cross-entropy loss
            loss = F.cross_entropy(shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1), ignore_index=-100)

        # Build the output: a CausalLMOutputWithPast instance
        # https://huggingface.co/docs/transformers/v5.0.0/en/main_classes/output#transformers.modeling_outputs.CausalLMOutputWithPast
        # - loss (torch.Tensor): cross-entropy loss during training; None at inference
        # - logits (torch.Tensor): prediction scores, shape (batch_size, seq_len, vocab_size) or the truncated length
        # - past_key_values (List[Tuple[torch.Tensor]]): each layer's KV cache, used for generating subsequent tokens
        # - hidden_states (torch.Tensor): outputs of all Transformer layers; not returned by default, but provided here
        # - aux_loss (torch.Tensor, custom field): the MoE auxiliary loss, nonzero only when MoE is enabled and training
        output = CausalLMOutputWithPast(loss=loss, logits=logits, past_key_values=past_key_values, hidden_states=hidden_states)
        output.aux_loss = aux_loss
        return output
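# ---------------------------------------------------------------------------
# Hedged note on the logits_to_keep slicing above (not part of the original
# file): slice(-0, None) is slice(0, None), so the default of 0 keeps every
# position, while any positive N keeps exactly the last N.
# ---------------------------------------------------------------------------
def _demo_logits_to_keep():
    hidden = torch.randn(1, 10, 8)  # (batch, seq_len, hidden)
    assert hidden[:, slice(-0, None), :].shape[1] == 10  # 0 -> keep all positions
    assert hidden[:, slice(-3, None), :].shape[1] == 3   # 3 -> keep the last 3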
@ -1,12 +1,13 @@
import os
import sys

# set the package name
__package__ = "trainer"
# add the project root to the system path so sibling modules can be imported
sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), '..')))

import argparse
import time
import warnings
import torch
import torch.nn.functional as F
import torch.distributed as dist
@ -18,52 +19,101 @@ from model.model_minimind import MiniMindConfig
from dataset.lm_dataset import SFTDataset
from trainer.trainer_utils import get_lr, Logger, is_main_process, lm_checkpoint, init_distributed_mode, setup_seed, init_model, SkipBatchSampler

# silence all warnings to keep the training output readable
warnings.filterwarnings('ignore')


def distillation_loss(student_logits, teacher_logits, temperature=1.0, reduction='batchmean'):
    """
    KL-divergence loss for knowledge distillation.

    Uses softmax and KL divergence to measure the gap between the student's and
    the teacher's predicted distributions. The temperature controls how smooth
    the distributions are: the higher the temperature, the smoother they get.

    Args:
    - student_logits: raw student logits, shape [batch, seq_len, vocab_size]
    - teacher_logits: raw teacher logits, shape [batch, seq_len, vocab_size]
    - temperature: distillation temperature used to smooth the softmax, 1.0 by default
    - reduction: loss aggregation mode, 'batchmean' (average over the batch) by default

    Returns:
    - the distillation loss, already scaled by temperature squared
    """
    # teacher distribution at the given temperature, detached so no gradient flows into the teacher
    with torch.no_grad():
        teacher_probs = F.softmax(teacher_logits / temperature, dim=-1).detach()

    # student log-probabilities at the same temperature
    student_log_probs = F.log_softmax(student_logits / temperature, dim=-1)

    # KL-divergence loss
    kl = F.kl_div(
        student_log_probs,
        teacher_probs,
        reduction=reduction
    )
    # scale by temperature squared to keep gradient magnitudes comparable to plain cross-entropy
    return (temperature ** 2) * kl
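# ---------------------------------------------------------------------------
# Hedged sanity check for distillation_loss above (not part of the original
# file): when the student matches the teacher exactly, the KL term vanishes;
# a perturbed student yields a positive loss, and a higher temperature smooths
# both distributions before comparing them.
# ---------------------------------------------------------------------------
def _demo_distillation_loss():
    teacher = torch.tensor([[2.0, 0.5, -1.0]])
    student_same = teacher.clone()
    student_off = teacher + torch.tensor([[0.0, 1.0, 0.0]])
    assert distillation_loss(student_same, teacher).item() < 1e-6
    assert distillation_loss(student_off, teacher).item() > 0.0
    # temperature > 1 softens the targets; the loss is rescaled by temperature ** 2
    _ = distillation_loss(student_off, teacher, temperature=2.0)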


def train_epoch(epoch, loader, iters, teacher_model, lm_config_student, start_step=0, wandb=None, alpha=0.0, temperature=1.0):
    """
    Train one epoch.

    Runs the forward pass, loss computation and backward update for knowledge
    distillation, combining two supervision signals: the standard cross-entropy
    loss and the distillation loss.

    Args:
    - epoch: current epoch (counted from 0)
    - loader: data loader providing the batches
    - iters: total number of iterations in this epoch
    - teacher_model: teacher model producing the soft labels for distillation
    - lm_config_student: student model config, used to check for the MoE architecture
    - start_step: starting step, used when resuming from a checkpoint
    - wandb: Weights & Biases logging object for training metrics
    - alpha: loss weight balancing CE and distillation: total loss = alpha*CE + (1-alpha)*KL
    - temperature: distillation temperature controlling distribution smoothness
    """
    start_time = time.time()

    # put the teacher into eval mode, disabling dropout and other stochastic ops
    if teacher_model is not None:
        teacher_model.eval()
        # freeze the teacher's parameters; no gradients are computed for it
        teacher_model.requires_grad_(False)

    # iterate over the batches from the data loader
    for step, (input_ids, labels) in enumerate(loader, start=start_step + 1):
        # move the data to the target device
        input_ids = input_ids.to(args.device)
        labels = labels.to(args.device)
        # loss mask marking which positions contribute to the loss (0 where ignore_index=-100)
        loss_mask = (labels[..., 1:] != -100).float()
        # current learning rate from the cosine-annealing schedule
        lr = get_lr(epoch * iters + step, args.epochs * iters, args.learning_rate)
        # update the learning rate of every parameter group in the optimizer
        for param_group in optimizer.param_groups:
            param_group['lr'] = lr

        # forward pass (student model)
        # - autocast_ctx is defined under __name__ == "__main__":
        # - autocast_ctx = nullcontext() if device_type == "cpu" else torch.amp.autocast(dtype=dtype)
        with autocast_ctx:
            res = model(input_ids)
            # student logits, dropping the last position's prediction (the labels are shifted right by one)
            student_logits = res.logits[..., :-1, :].contiguous()

        # teacher forward pass (only in eval & no_grad mode)
        if teacher_model is not None:
            with torch.no_grad():
                teacher_logits = teacher_model(input_ids).logits[..., :-1, :].contiguous()
            # if the student and teacher vocabularies differ in size, keep only the teacher's first vocab_size_student entries
            vocab_size_student = student_logits.size(-1)
            teacher_logits = teacher_logits[..., :vocab_size_student]

        # Compute the losses
        # 1) Ground-truth CE loss (standard cross-entropy)
        shift_labels = labels[..., 1:].contiguous()
        loss_mask_flat = loss_mask.view(-1)
        ce_loss = F.cross_entropy(
@ -72,42 +122,61 @@ def train_epoch(epoch, loader, iters, teacher_model, lm_config_student, start_st
            ignore_index=-100,
            reduction='none'
        )
        # masked average, ignoring the masked-out positions
        ce_loss_raw = torch.sum(ce_loss * loss_mask_flat) / (loss_mask_flat.sum() + 1e-8)

        # for an MoE model, add the auxiliary loss
        if lm_config_student.use_moe:
            ce_loss = ce_loss_raw + res.aux_loss
        else:
            ce_loss = ce_loss_raw

        # 2) Distillation loss
        if teacher_model is not None:
            # computed only on the unmasked positions
            distill_loss = distillation_loss(
                student_logits.view(-1, student_logits.size(-1))[loss_mask_flat == 1],
                teacher_logits.view(-1, teacher_logits.size(-1))[loss_mask_flat == 1],
                temperature=temperature
            )
        else:
            # no teacher model: the distillation loss is 0
            distill_loss = torch.tensor(0.0, device=args.device)

        # 3) total loss = alpha * CE + (1-alpha) * Distill, divided by the gradient-accumulation steps
        loss = (alpha * ce_loss + (1 - alpha) * distill_loss) / args.accumulation_steps

        # backward pass
        scaler.scale(loss).backward()

        # once enough gradient-accumulation steps have run, apply the optimizer update
        if (step + 1) % args.accumulation_steps == 0:
            # unscale the gradients before the update
            scaler.unscale_(optimizer)
            # clip the gradients to avoid gradient explosion
            torch.nn.utils.clip_grad_norm_(model.parameters(), args.grad_clip)
            # optimizer update
            scaler.step(optimizer)
            # update the scaler
            scaler.update()
            # clear the gradients, assigning None to save memory
            optimizer.zero_grad(set_to_none=True)

        # logging
        if step % args.log_interval == 0 or step == iters - 1:
            spend_time = time.time() - start_time
            # recover the real loss value (multiply the accumulation steps back in)
            current_loss = loss.item() * args.accumulation_steps
            current_ce_loss = ce_loss_raw.item()
            current_aux_loss = res.aux_loss.item() if lm_config_student.use_moe else 0.0
            current_lr = optimizer.param_groups[-1]['lr']
            # estimated remaining time
            eta_min = spend_time / (step + 1) * iters // 60 - spend_time // 60

            # print the training log
            Logger(f'Epoch:[{epoch + 1}/{args.epochs}]({step}/{iters}), loss: {current_loss:.4f}, ce: {current_ce_loss:.4f}, aux_loss: {current_aux_loss:.4f}, distill: {distill_loss.item():.4f}, learning_rate: {current_lr:.8f}, epoch_time: {eta_min:.3f}min')

            # log to wandb
            if wandb:
                wandb.log({
                    "loss": current_loss,
@ -118,18 +187,26 @@ def train_epoch(epoch, loader, iters, teacher_model, lm_config_student, start_st
                    "epoch_time": eta_min
                })

        # checkpoint saving
        if (step % args.save_interval == 0 or step == iters - 1) and is_main_process():
            model.eval()
            # append a suffix for MoE models
            moe_suffix = '_moe' if lm_config_student.use_moe else ''
            ckp = f'{args.save_dir}/{args.save_weight}_{lm_config_student.hidden_size}{moe_suffix}.pth'
            # unwrap to the raw model (strip the DDP and torch.compile wrappers)
            raw_model = model.module if isinstance(model, DistributedDataParallel) else model
            raw_model = getattr(raw_model, '_orig_mod', raw_model)
            # save the model weights in half precision
            state_dict = raw_model.state_dict()
            torch.save({k: v.half().cpu() for k, v in state_dict.items()}, ckp)
            # save the full checkpoint (including optimizer state etc.)
            lm_checkpoint(lm_config_student, weight=args.save_weight, model=model, optimizer=optimizer, scaler=scaler, epoch=epoch, step=step, wandb=wandb, save_dir='../checkpoints')
            # back to training mode
            model.train()
            # free memory
            del state_dict

        # release this batch's tensors to free GPU memory
        del input_ids, labels, loss_mask, res, student_logits, ce_loss, distill_loss, loss


@ -147,40 +224,42 @@ if __name__ == "__main__":
    parser.add_argument("--grad_clip", type=float, default=1.0, help="gradient clipping threshold")
    parser.add_argument("--log_interval", type=int, default=100, help="logging interval")
    parser.add_argument("--save_interval", type=int, default=100, help="model saving interval")
    parser.add_argument("--max_seq_len", type=int, default=340, help="maximum truncation length for training (1 Chinese token ≈ 1.5~1.7 characters)")
    parser.add_argument("--data_path", type=str, default="../dataset/sft_mini_512.jsonl", help="training data path")
    parser.add_argument('--student_hidden_size', default=512, type=int, help="student model hidden size")
    parser.add_argument('--student_num_layers', default=8, type=int, help="student model layer count")
    parser.add_argument('--teacher_hidden_size', default=768, type=int, help="teacher model hidden size")
    parser.add_argument('--teacher_num_layers', default=16, type=int, help="teacher model layer count")
    parser.add_argument('--use_moe', default=0, type=int, choices=[0, 1], help="whether to use the MoE architecture (0=no, 1=yes)")
    parser.add_argument('--from_student_weight', default='full_sft', type=str, help="weights the student model starts from")
    parser.add_argument('--from_teacher_weight', default='full_sft', type=str, help="weights the teacher model starts from")
    parser.add_argument('--from_resume', default=0, type=int, choices=[0, 1], help="auto-detect & resume training (0=no, 1=yes)")
    parser.add_argument('--alpha', default=0.5, type=float, help="CE loss weight; total loss = alpha*CE + (1-alpha)*KL")
    parser.add_argument('--temperature', default=1.5, type=float, help="distillation temperature (recommended range 1.0-2.0)")
    parser.add_argument("--use_wandb", action="store_true", help="whether to use wandb")
    parser.add_argument("--wandb_project", type=str, default="MiniMind-Distillation", help="wandb project name")
    parser.add_argument("--use_compile", default=0, type=int, choices=[0, 1], help="whether to accelerate with torch.compile (0=no, 1=yes)")
    args = parser.parse_args()
|
||||
|
||||
# ========== 1. 初始化环境和随机种子 ==========
|
||||
# ========== 1.初始化环境和随机种子 ==========
|
||||
local_rank = init_distributed_mode()
|
||||
if dist.is_initialized(): args.device = f"cuda:{local_rank}"
|
||||
setup_seed(42 + (dist.get_rank() if dist.is_initialized() else 0))
|
||||
|
||||
# ========== 2. 配置目录、模型参数、检查ckp ==========
|
||||
# ========== 2.配置目录、模型参数、检查ckp ==========
|
||||
os.makedirs(args.save_dir, exist_ok=True)
|
||||
lm_config_student = MiniMindConfig(hidden_size=args.student_hidden_size, num_hidden_layers=args.student_num_layers, use_moe=bool(args.use_moe))
|
||||
lm_config_teacher = MiniMindConfig(hidden_size=args.teacher_hidden_size, num_hidden_layers=args.teacher_num_layers, use_moe=bool(args.use_moe))
|
||||
ckp_data = lm_checkpoint(lm_config_student, weight=args.save_weight, save_dir='../checkpoints') if args.from_resume==1 else None
|
||||
|
||||
# ========== 3. 设置混合精度 ==========
|
||||
# ========== 3.配置混合精度训练 ==========
|
||||
device_type = "cuda" if "cuda" in args.device else "cpu"
|
||||
dtype = torch.bfloat16 if args.dtype == "bfloat16" else torch.float16
|
||||
autocast_ctx = nullcontext() if device_type == "cpu" else torch.cuda.amp.autocast(dtype=dtype)
|
||||
# https://docs.pytorch.ac.cn/docs/stable/amp
|
||||
# 新版废弃: autocast_ctx = nullcontext() if device_type == "cpu" else torch.cuda.amp.autocast(dtype=dtype)
|
||||
autocast_ctx = nullcontext() if device_type == "cpu" else torch.amp.autocast(device_type=device_type, dtype=dtype)
|
||||
|
||||
# ========== 4. 配wandb ==========
|
||||
# ========== 4.配置 wandb ==========
|
||||
wandb = None
|
||||
if args.use_wandb and is_main_process():
|
||||
import swanlab as wandb
|
||||
@ -189,22 +268,32 @@ if __name__ == "__main__":
|
||||
wandb_run_name = f"MiniMind-Distill-S{args.student_hidden_size}T{args.teacher_hidden_size}-Epoch-{args.epochs}-BS-{args.batch_size}-LR-{args.learning_rate}"
|
||||
wandb.init(project=args.wandb_project, name=wandb_run_name, id=wandb_id, resume=resume)
|
||||
|
||||
# ========== 5. 定义学生和教师模型 ==========
|
||||
# ========== 5.定义学生和教师模型 ==========
|
||||
# 学生模型和分词器
|
||||
model, tokenizer = init_model(lm_config_student, args.from_student_weight, device=args.device)
|
||||
# 使用 torch.compile 加速模型 (如果支持)
|
||||
if args.use_compile == 1:
|
||||
model = torch.compile(model)
|
||||
Logger('torch.compile enabled')
|
||||
Logger(f'学生模型总参数量:{sum(p.numel() for p in model.parameters()) / 1e6:.3f} M')
|
||||
# 教师模型
|
||||
teacher_model, _ = init_model(lm_config_teacher, args.from_teacher_weight, device=args.device)
|
||||
teacher_model.eval()
|
||||
teacher_model.requires_grad_(False)
|
||||
Logger(f'教师模型总参数量:{sum(p.numel() for p in teacher_model.parameters()) / 1e6:.3f} M')
|
||||
|
||||
# 创建训练数据集
|
||||
train_ds = SFTDataset(args.data_path, tokenizer, max_length=args.max_seq_len)
|
||||
# 创建分布式采样器 (如果在分布式环境下)
|
||||
train_sampler = DistributedSampler(train_ds) if dist.is_initialized() else None
|
||||
scaler = torch.cuda.amp.GradScaler(enabled=(args.dtype == 'float16'))
|
||||
# 创建梯度缩放器 (混合精度训练需要)
|
||||
# https://docs.pytorch.ac.cn/docs/stable/amp.html#gradient-scaling
|
||||
# 新版废弃: scaler = torch.cuda.amp.GradScaler(enabled=(args.dtype == 'float16'))
|
||||
scaler = torch.amp.GradScaler("cuda", enabled=(args.dtype == 'float16'))
|
||||
# 创建优化器
|
||||
optimizer = optim.AdamW(model.parameters(), lr=args.learning_rate)
|
||||
|
||||
# ========== 6. 从ckp恢复状态 ==========
|
||||
# ========== 6.从检查点恢复训练状态 ==========
|
||||
start_epoch, start_step = 0, 0
|
||||
if ckp_data:
|
||||
model.load_state_dict(ckp_data['model'])
|
||||
@ -213,23 +302,29 @@ if __name__ == "__main__":
|
||||
start_epoch = ckp_data['epoch']
|
||||
start_step = ckp_data.get('step', 0)
|
||||
|
||||
# ========== 7. DDP包模型 ==========
|
||||
# ========== 7.使用 DDP 包装模型 (如果在分布式环境下) ==========
|
||||
if dist.is_initialized():
|
||||
model._ddp_params_and_buffers_to_ignore = {"freqs_cos", "freqs_sin"}
|
||||
model = DistributedDataParallel(model, device_ids=[local_rank])
|
||||
|
||||
# ========== 8. 开始训练 ==========
|
||||
# ========== 8.开始训练 ==========
|
||||
for epoch in range(start_epoch, args.epochs):
|
||||
# 设置分布式采样器的 epoch (打乱数据)
|
||||
train_sampler and train_sampler.set_epoch(epoch)
|
||||
setup_seed(42 + epoch); indices = torch.randperm(len(train_ds)).tolist()
|
||||
setup_seed(42 + epoch)
|
||||
# 随机打乱数据集索引
|
||||
indices = torch.randperm(len(train_ds)).tolist()
|
||||
# 计算需要跳过的批次数量
|
||||
skip = start_step if (epoch == start_epoch and start_step > 0) else 0
|
||||
# 创建批次采样器
|
||||
batch_sampler = SkipBatchSampler(train_sampler or indices, args.batch_size, skip)
|
||||
loader = DataLoader(train_ds, batch_sampler=batch_sampler, num_workers=args.num_workers, pin_memory=True)
|
||||
# 如果有跳过的批次, 打印提示信息并恢复训练
|
||||
if skip > 0:
|
||||
Logger(f'Epoch [{epoch + 1}/{args.epochs}]: 跳过前{start_step}个step,从step {start_step + 1}开始')
|
||||
train_epoch(epoch, loader, len(loader) + skip, teacher_model, lm_config_student, start_step, wandb, args.alpha, args.temperature)
|
||||
else:
|
||||
train_epoch(epoch, loader, len(loader), teacher_model, lm_config_student, 0, wandb, args.alpha, args.temperature)
|
||||
|
||||
# ========== 9. 清理分布进程 ==========
|
||||
# ========== 9.清理分布进程 ==========
|
||||
if dist.is_initialized(): dist.destroy_process_group()
|
||||
@ -1,12 +1,13 @@
|
||||
import os
|
||||
import sys
|
||||
|
||||
# 设置包名
|
||||
__package__ = "trainer"
|
||||
# 将项目根目录添加到系统路径, 确保能够导入同级目录的模块
|
||||
sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), '..')))
|
||||
|
||||
import argparse
|
||||
import time
|
||||
import warnings
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
import torch.distributed as dist
|
||||
@ -18,10 +19,25 @@ from model.model_minimind import MiniMindConfig
|
||||
from dataset.lm_dataset import DPODataset
|
||||
from trainer.trainer_utils import get_lr, Logger, is_main_process, lm_checkpoint, init_distributed_mode, setup_seed, init_model, SkipBatchSampler
|
||||
|
||||
# 忽略所有警告信息, 保持输出整洁
|
||||
import warnings
|
||||
warnings.filterwarnings('ignore')
|
||||
|
||||
|
||||
def logits_to_log_probs(logits, labels):
|
||||
"""
|
||||
将 logits 转换为对数概率
|
||||
|
||||
对每个位置的 logits 计算 log_softmax, 然后根据 labels 提取对应的对数概率
|
||||
这是 DPO 训练中的基础操作, 用于计算策略模型和参考模型的 log probabilities
|
||||
|
||||
Args:
|
||||
- logits: 模型输出的 logits, shape: (batch_size, seq_len, vocab_size)
|
||||
- labels: 目标 token 序列, shape: (batch_size, seq_len)
|
||||
|
||||
Returns:
|
||||
- log_probs_per_token: 每个 token 的对数概率, shape: (batch_size, seq_len)
|
||||
"""
|
||||
# logits shape: (batch_size, seq_len, vocab_size)
|
||||
# labels shape: (batch_size, seq_len)
|
||||
# log_probs shape: (batch_size, seq_len)
|
||||
@ -31,19 +47,36 @@ def logits_to_log_probs(logits, labels):
|
||||
|
||||
|
||||
def dpo_loss(ref_log_probs, policy_log_probs, mask, beta):
|
||||
"""
|
||||
计算 DPO (Direct Preference Optimization) 损失
|
||||
|
||||
DPO 是一种直接根据人类偏好数据优化语言模型的方法, 不需要训练奖励模型
|
||||
通过比较策略模型和参考模型在 chosen/rejected 样本上的对数概率比率来计算损失
|
||||
|
||||
Args:
|
||||
- ref_log_probs: 参考模型对数概率, shape: (batch_size, seq_len)
|
||||
- policy_log_probs: 策略模型对数概率, shape: (batch_size, seq_len)
|
||||
- mask: 有效 token 掩码, shape: (batch_size, seq_len)
|
||||
- beta: DPO 温度参数, 控制模型偏离参考模型的程度
|
||||
|
||||
Returns:
|
||||
- loss: 平均 DPO 损失值
|
||||
"""
|
||||
# ref_log_probs 和 policy_log_probs 都是 shape: (batch_size, seq_len)
|
||||
# https://github.com/jingyaogong/minimind/issues/298
|
||||
seq_lengths = mask.sum(dim=1, keepdim=True).clamp_min(1e-8) # 防止零长度mask导致除零NaN
|
||||
# 计算每个序列的有效长度, 用于对 log probs 进行归一化
|
||||
seq_lengths = mask.sum(dim=1, keepdim=True).clamp_min(1e-8) # 防止零长度 mask 导致除零 NaN
|
||||
ref_log_probs = (ref_log_probs * mask).sum(dim=1) / seq_lengths.squeeze()
|
||||
policy_log_probs = (policy_log_probs * mask).sum(dim=1) / seq_lengths.squeeze()
|
||||
|
||||
# 将 chosen 和 rejected 数据分开
|
||||
# 将 chosen 和 rejected 数据分开, 批次的前半部分是 chosen, 后半部分是 rejected
|
||||
batch_size = ref_log_probs.shape[0]
|
||||
chosen_ref_log_probs = ref_log_probs[:batch_size // 2]
|
||||
reject_ref_log_probs = ref_log_probs[batch_size // 2:]
|
||||
chosen_policy_log_probs = policy_log_probs[:batch_size // 2]
|
||||
reject_policy_log_probs = policy_log_probs[batch_size // 2:]
|
||||
|
||||
# 计算对数概率比率, 衡量策略模型相比参考模型的改进程度
|
||||
pi_logratios = chosen_policy_log_probs - reject_policy_log_probs
|
||||
ref_logratios = chosen_ref_log_probs - reject_ref_log_probs
|
||||
logits = pi_logratios - ref_logratios
|
||||
@ -52,39 +85,62 @@ def dpo_loss(ref_log_probs, policy_log_probs, mask, beta):
|
||||
|
||||
|
||||
def train_epoch(epoch, loader, iters, ref_model, lm_config, start_step=0, wandb=None, beta=0.1):
|
||||
"""
|
||||
执行一个 epoch 的 DPO 训练
|
||||
|
||||
Args:
|
||||
- epoch: 当前训练轮次
|
||||
- loader: 数据加载器
|
||||
- iters: 当前 epoch 的总迭代次数
|
||||
- ref_model: 参考模型 (冻结参数, 不更新)
|
||||
- lm_config: 模型配置对象
|
||||
- start_step: 从此步骤开始训练 (用于恢复训练)
|
||||
- wandb: Weights & Biases 日志对象
|
||||
- beta: DPO 温度参数
|
||||
"""
|
||||
start_time = time.time()
|
||||
|
||||
for step, batch in enumerate(loader, start=start_step + 1):
|
||||
# 从批次中提取 chosen 和 rejected 数据, 并移动到目标设备
|
||||
x_chosen = batch['x_chosen'].to(args.device)
|
||||
x_rejected = batch['x_rejected'].to(args.device)
|
||||
y_chosen = batch['y_chosen'].to(args.device)
|
||||
y_rejected = batch['y_rejected'].to(args.device)
|
||||
mask_chosen = batch['mask_chosen'].to(args.device)
|
||||
mask_rejected = batch['mask_rejected'].to(args.device)
|
||||
|
||||
# 将 chosen 和 rejected 数据拼接在一起, 形成完整的批次
|
||||
# 批次的前半部分是 chosen (用户偏好的回复), 后半部分是 rejected (用户不喜欢的回复)
|
||||
x = torch.cat([x_chosen, x_rejected], dim=0)
|
||||
y = torch.cat([y_chosen, y_rejected], dim=0)
|
||||
mask = torch.cat([mask_chosen, mask_rejected], dim=0)
|
||||
|
||||
# 根据当前步骤动态计算学习率 (余弦退火策略)
|
||||
lr = get_lr(epoch * iters + step, args.epochs * iters, args.learning_rate)
|
||||
for param_group in optimizer.param_groups:
|
||||
param_group['lr'] = lr
|
||||
|
||||
with autocast_ctx:
|
||||
# 参考模型前向传播, 不计算梯度 (frozen)
|
||||
with torch.no_grad():
|
||||
ref_outputs = ref_model(x)
|
||||
ref_logits = ref_outputs.logits
|
||||
ref_log_probs = logits_to_log_probs(ref_logits, y)
|
||||
|
||||
# 策略模型前向传播, 计算 logits 和 log probabilities
|
||||
outputs = model(x)
|
||||
logits = outputs.logits
|
||||
policy_log_probs = logits_to_log_probs(logits, y)
|
||||
|
||||
# 计算 DPO 损失和辅助损失 (MoE 负载均衡损失)
|
||||
dpo_loss_val = dpo_loss(ref_log_probs, policy_log_probs, mask, beta=beta)
|
||||
loss = dpo_loss_val + outputs.aux_loss
|
||||
loss = loss / args.accumulation_steps
|
||||
|
||||
# 使用梯度缩放器进行反向传播, 支持混合精度训练
|
||||
scaler.scale(loss).backward()
|
||||
|
||||
# 当累积达到指定步数时, 执行参数更新
|
||||
if (step + 1) % args.accumulation_steps == 0:
|
||||
scaler.unscale_(optimizer)
|
||||
torch.nn.utils.clip_grad_norm_(model.parameters(), args.grad_clip)
|
||||
@ -92,6 +148,7 @@ def train_epoch(epoch, loader, iters, ref_model, lm_config, start_step=0, wandb=
|
||||
scaler.update()
|
||||
optimizer.zero_grad(set_to_none=True)
|
||||
|
||||
# 按指定间隔打印日志
|
||||
if step % args.log_interval == 0 or step == iters - 1:
|
||||
spend_time = time.time() - start_time
|
||||
current_loss = loss.item() * args.accumulation_steps
|
||||
@ -104,10 +161,12 @@ def train_epoch(epoch, loader, iters, ref_model, lm_config, start_step=0, wandb=
|
||||
|
||||
if wandb: wandb.log({"loss": current_loss, "dpo_loss": current_dpo_loss, "aux_loss": current_aux_loss, "learning_rate": current_lr, "epoch_time": eta_min})
|
||||
|
||||
# 按指定间隔保存模型检查点, 仅主进程执行
|
||||
if (step % args.save_interval == 0 or step == iters - 1) and is_main_process():
|
||||
model.eval()
|
||||
moe_suffix = '_moe' if lm_config.use_moe else ''
|
||||
ckp = f'{args.save_dir}/{args.save_weight}_{lm_config.hidden_size}{moe_suffix}.pth'
|
||||
# 解包 DDP 包装的模型以获取原始模型
|
||||
raw_model = model.module if isinstance(model, DistributedDataParallel) else model
|
||||
raw_model = getattr(raw_model, '_orig_mod', raw_model)
|
||||
state_dict = raw_model.state_dict()
|
||||
@ -116,17 +175,19 @@ def train_epoch(epoch, loader, iters, ref_model, lm_config, start_step=0, wandb=
|
||||
model.train()
|
||||
del state_dict
|
||||
|
||||
# 删除中间变量以释放内存
|
||||
del x_chosen, x_rejected, y_chosen, y_rejected, mask_chosen, mask_rejected, x, y, mask
|
||||
del ref_outputs, ref_logits, ref_log_probs, outputs, logits, policy_log_probs, loss
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
# 创建参数解析器, 定义 DPO 训练的所有配置参数
|
||||
parser = argparse.ArgumentParser(description="MiniMind DPO (Direct Preference Optimization)")
|
||||
parser.add_argument("--save_dir", type=str, default="../out", help="模型保存目录")
|
||||
parser.add_argument('--save_weight', default='dpo', type=str, help="保存权重的前缀名")
|
||||
parser.add_argument("--epochs", type=int, default=1, help="训练轮数")
|
||||
parser.add_argument("--batch_size", type=int, default=4, help="batch size")
|
||||
parser.add_argument("--learning_rate", type=float, default=4e-8, help="初始学习率(建议<=5e-8避免遗忘)")
|
||||
parser.add_argument("--learning_rate", type=float, default=4e-8, help="初始学习率 (建议<=5e-8避免遗忘)")
|
||||
parser.add_argument("--device", type=str, default="cuda:0" if torch.cuda.is_available() else "cpu", help="训练设备")
|
||||
parser.add_argument("--dtype", type=str, default="bfloat16", help="混合精度类型")
|
||||
parser.add_argument("--num_workers", type=int, default=8, help="数据加载线程数")
|
||||
@ -136,84 +197,113 @@ if __name__ == "__main__":
|
||||
parser.add_argument("--save_interval", type=int, default=100, help="模型保存间隔")
|
||||
parser.add_argument('--hidden_size', default=512, type=int, help="隐藏层维度")
|
||||
parser.add_argument('--num_hidden_layers', default=8, type=int, help="隐藏层数量")
|
||||
parser.add_argument('--max_seq_len', default=1024, type=int, help="训练的最大截断长度(中文1token≈1.5~1.7字符)")
|
||||
parser.add_argument('--use_moe', default=0, type=int, choices=[0, 1], help="是否使用MoE架构(0=否,1=是)")
|
||||
parser.add_argument('--max_seq_len', default=1024, type=int, help="训练的最大截断长度 (中文1token≈1.5~1.7字符)")
|
||||
parser.add_argument('--use_moe', default=0, type=int, choices=[0, 1], help="是否使用MoE架构 (0=否, 1=是)")
|
||||
parser.add_argument("--data_path", type=str, default="../dataset/dpo.jsonl", help="DPO训练数据路径")
|
||||
parser.add_argument('--from_weight', default='full_sft', type=str, help="基于哪个权重训练")
|
||||
parser.add_argument('--from_resume', default=0, type=int, choices=[0, 1], help="是否自动检测&续训(0=否,1=是)")
|
||||
parser.add_argument('--from_resume', default=0, type=int, choices=[0, 1], help="是否自动检测&续训 (0=否, 1=是)")
|
||||
parser.add_argument('--beta', default=0.1, type=float, help="DPO中的beta参数")
|
||||
parser.add_argument("--use_wandb", action="store_true", help="是否使用wandb")
|
||||
parser.add_argument("--wandb_project", type=str, default="MiniMind-DPO", help="wandb项目名")
|
||||
parser.add_argument("--use_compile", default=0, type=int, choices=[0, 1], help="是否使用torch.compile加速(0=否,1=是)")
|
||||
parser.add_argument("--use_compile", default=0, type=int, choices=[0, 1], help="是否使用torch.compile加速 (0=否, 1=是)")
|
||||
args = parser.parse_args()
|
||||
|
||||
# ========== 1. 初始化环境和随机种子 ==========
|
||||
# ========== 1.初始化环境和随机种子 ==========
|
||||
# 初始化分布式训练模式, 返回本地 rank (非 DDP 模式下返回 0)
|
||||
local_rank = init_distributed_mode()
|
||||
# 如果处于分布式训练环境, 根据 local_rank 设置当前设备
|
||||
if dist.is_initialized(): args.device = f"cuda:{local_rank}"
|
||||
# 设置随机种子, 分布式训练时根据 rank 偏移种子以确保不同进程使用不同的数据顺序
|
||||
setup_seed(42 + (dist.get_rank() if dist.is_initialized() else 0))
|
||||
|
||||
# ========== 2. 配置目录、模型参数、检查ckp ==========
|
||||
# ========== 2.配置目录、模型参数、检查检查点 ==========
|
||||
# 创建保存目录 (如果不存在)
|
||||
os.makedirs(args.save_dir, exist_ok=True)
|
||||
# 创建 MiniMind 模型配置
|
||||
lm_config = MiniMindConfig(hidden_size=args.hidden_size, num_hidden_layers=args.num_hidden_layers, use_moe=bool(args.use_moe))
|
||||
# 如果启用续训模式, 尝试从检查点恢复训练状态
|
||||
ckp_data = lm_checkpoint(lm_config, weight=args.save_weight, save_dir='../checkpoints') if args.from_resume==1 else None
|
||||
|
||||
# ========== 3. 设置混合精度 ==========
|
||||
# ========== 3.设置混合精度训练 ==========
|
||||
device_type = "cuda" if "cuda" in args.device else "cpu"
|
||||
# 根据配置选择数据类型 (bfloat16 或 float16)
|
||||
dtype = torch.bfloat16 if args.dtype == "bfloat16" else torch.float16
|
||||
autocast_ctx = nullcontext() if device_type == "cpu" else torch.cuda.amp.autocast(dtype=dtype)
|
||||
# https://docs.pytorch.ac.cn/docs/stable/amp
|
||||
# 新版废弃: autocast_ctx = nullcontext() if device_type == "cpu" else torch.cuda.amp.autocast(dtype=dtype)
|
||||
autocast_ctx = nullcontext() if device_type == "cpu" else torch.amp.autocast(device_type=device_type, dtype=dtype)
|
||||
|
||||
# ========== 4. 配wandb ==========
|
||||
# ========== 4.配置 wandb 日志 ==========
|
||||
wandb = None
|
||||
# 仅在主进程中初始化 wandb, 避免重复记录
|
||||
if args.use_wandb and is_main_process():
|
||||
import swanlab as wandb
|
||||
# 从检查点恢复 wandb 运行 ID (如果存在)
|
||||
wandb_id = ckp_data.get('wandb_id') if ckp_data else None
|
||||
resume = 'must' if wandb_id else None
|
||||
# 构建运行名称, 包含关键训练参数
|
||||
wandb_run_name = f"MiniMind-DPO-Epoch-{args.epochs}-BatchSize-{args.batch_size}-LR-{args.learning_rate}"
|
||||
wandb.init(project=args.wandb_project, name=wandb_run_name, id=wandb_id, resume=resume)
|
||||
|
||||
# ========== 5. 定义模型和参考模型 ==========
|
||||
# ========== 5.初始化模型和参考模型 ==========
|
||||
model, tokenizer = init_model(lm_config, args.from_weight, device=args.device)
|
||||
# 如果启用 torch.compile, 使用 PyTorch 2.0 的编译加速功能
|
||||
if args.use_compile == 1:
|
||||
model = torch.compile(model)
|
||||
Logger('torch.compile enabled')
|
||||
Logger(f'策略模型总参数量:{sum(p.numel() for p in model.parameters()) / 1e6:.3f} M')
|
||||
# 初始化参考模型(ref_model冻结)
|
||||
Logger(f'策略模型总参数量: {sum(p.numel() for p in model.parameters()) / 1e6:.3f} M')
|
||||
# 初始化参考模型 (ref_model 冻结, 不参与训练, 仅用于计算参考 logits)
|
||||
ref_model, _ = init_model(lm_config, args.from_weight, device=args.device)
|
||||
ref_model.eval()
|
||||
ref_model.requires_grad_(False)
|
||||
Logger(f'参考模型总参数量:{sum(p.numel() for p in ref_model.parameters()) / 1e6:.3f} M')
|
||||
Logger(f'参考模型总参数量: {sum(p.numel() for p in ref_model.parameters()) / 1e6:.3f} M')
|
||||
|
||||
# 创建 DPO 数据集和数据采样器
|
||||
train_ds = DPODataset(args.data_path, tokenizer, max_length=args.max_seq_len)
|
||||
# 分布式训练时使用 DistributedSampler 确保数据分配均匀
|
||||
train_sampler = DistributedSampler(train_ds) if dist.is_initialized() else None
|
||||
scaler = torch.cuda.amp.GradScaler(enabled=(args.dtype == 'float16'))
|
||||
# 创建梯度缩放器, 仅在 float16 模式下启用 (用于混合精度训练)
|
||||
# https://docs.pytorch.ac.cn/docs/stable/amp.html#gradient-scaling
|
||||
# 新版废弃: scaler = torch.cuda.amp.GradScaler(enabled=(args.dtype == 'float16'))
|
||||
scaler = torch.amp.GradScaler("cuda", enabled=(args.dtype == 'float16'))
|
||||
# 创建 AdamW 优化器
|
||||
optimizer = optim.AdamW(model.parameters(), lr=args.learning_rate)
|
||||
|
||||
# ========== 6. 从ckp恢复状态 ==========
|
||||
# ========== 6.从检查点恢复训练状态 ==========
|
||||
start_epoch, start_step = 0, 0
|
||||
if ckp_data:
|
||||
# 从检查点加载模型权重、优化器状态、梯度缩放器状态
|
||||
model.load_state_dict(ckp_data['model'])
|
||||
optimizer.load_state_dict(ckp_data['optimizer'])
|
||||
scaler.load_state_dict(ckp_data['scaler'])
|
||||
# 恢复训练进度
|
||||
start_epoch = ckp_data['epoch']
|
||||
start_step = ckp_data.get('step', 0)
|
||||
|
||||
# ========== 7. DDP包模型 ==========
|
||||
# ========== 7.使用 DDP 包装模型 (分布式训练) ==========
|
||||
if dist.is_initialized():
|
||||
# 忽略位置编码相关的参数 (不需要梯度同步)
|
||||
model._ddp_params_and_buffers_to_ignore = {"freqs_cos", "freqs_sin"}
|
||||
# 使用 DistributedDataParallel 包装模型, 启用分布式训练
|
||||
model = DistributedDataParallel(model, device_ids=[local_rank])
|
||||
|
||||
# ========== 8. 开始训练 ==========
|
||||
# ========== 8.开始训练循环 ==========
|
||||
for epoch in range(start_epoch, args.epochs):
|
||||
# 分布式训练时设置 epoch 以确保不同进程使用不同的数据顺序
|
||||
train_sampler and train_sampler.set_epoch(epoch)
|
||||
# 设置当前 epoch 的随机种子, 生成随机索引序列
|
||||
setup_seed(42 + epoch); indices = torch.randperm(len(train_ds)).tolist()
|
||||
# 如果是恢复训练的第一个 epoch, 计算需要跳过的步数
|
||||
skip = start_step if (epoch == start_epoch and start_step > 0) else 0
|
||||
# 创建支持跳过功能的批次采样器
|
||||
batch_sampler = SkipBatchSampler(train_sampler or indices, args.batch_size, skip)
|
||||
# 创建数据加载器
|
||||
loader = DataLoader(train_ds, batch_sampler=batch_sampler, num_workers=args.num_workers, pin_memory=True)
|
||||
# 执行当前 epoch 的训练
|
||||
if skip > 0:
|
||||
Logger(f'Epoch [{epoch + 1}/{args.epochs}]: 跳过前{start_step}个step,从step {start_step + 1}开始')
|
||||
Logger(f'Epoch [{epoch + 1}/{args.epochs}]: 跳过前{start_step}个 step, 从 step {start_step + 1}开始')
|
||||
train_epoch(epoch, loader, len(loader) + skip, ref_model, lm_config, start_step, wandb, args.beta)
|
||||
else:
|
||||
train_epoch(epoch, loader, len(loader), ref_model, lm_config, 0, wandb, args.beta)
|
||||
|
||||
# ========== 9. 清理分布进程 ==========
|
||||
# ========== 9.清理分布式进程组 ==========
|
||||
if dist.is_initialized(): dist.destroy_process_group()
|
||||
@ -1,12 +1,13 @@
|
||||
import os
|
||||
import sys
|
||||
|
||||
# 设置包名
|
||||
__package__ = "trainer"
|
||||
# 将项目根目录添加到系统路径, 确保能够导入同级目录的模块
|
||||
sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), '..')))
|
||||
|
||||
import argparse
|
||||
import time
|
||||
import warnings
|
||||
import torch
|
||||
import torch.distributed as dist
|
||||
from contextlib import nullcontext
|
||||
@ -17,57 +18,112 @@ from model.model_minimind import MiniMindConfig
|
||||
from dataset.lm_dataset import SFTDataset
|
||||
from trainer.trainer_utils import get_lr, Logger, is_main_process, lm_checkpoint, init_distributed_mode, setup_seed, init_model, SkipBatchSampler
|
||||
|
||||
# 忽略所有警告信息, 避免训练过程中输出无关的警告日志
|
||||
import warnings
|
||||
warnings.filterwarnings('ignore')
|
||||
|
||||
|
||||
def train_epoch(epoch, loader, iters, start_step=0, wandb=None):
|
||||
"""
|
||||
单个 epoch 的训练函数
|
||||
|
||||
执行一个完整的训练周期, 包括前向传播、反向传播、梯度更新、日志打印和模型保存
|
||||
|
||||
Args:
|
||||
- epoch: 当前训练轮次, 从 0 开始计数
|
||||
- loader: 数据加载器, 提供训练数据批次
|
||||
- iters: 当前 epoch 的总迭代次数
|
||||
- start_step: 起始步数, 用于恢复训练时跳过已训练的批次, 默认为 0
|
||||
- wandb: Weights & Biases 日志对象, 为 None 时不记录日志
|
||||
"""
|
||||
# 记录 epoch 开始时间, 用于计算 ETA (预计剩余时间)
|
||||
start_time = time.time()
|
||||
|
||||
# 遍历数据加载器中的所有批次
|
||||
# enumerate 从 start_step + 1 开始, 保证恢复训练时 step 编号连续
|
||||
for step, (input_ids, labels) in enumerate(loader, start=start_step + 1):
|
||||
# 将输入数据和标签移动到训练设备 (GPU 或 CPU)
|
||||
input_ids = input_ids.to(args.device)
|
||||
labels = labels.to(args.device)
|
||||
|
||||
# 根据余弦退火策略计算当前步的学习率
|
||||
# 公式参数: 当前全局步数 (epoch * iters + step), 总步数 (epochs * iters), 基础学习率
|
||||
lr = get_lr(epoch * iters + step, args.epochs * iters, args.learning_rate)
|
||||
# 更新优化器中所有参数组的学习率
|
||||
for param_group in optimizer.param_groups:
|
||||
param_group['lr'] = lr
|
||||
|
||||
# 使用混合精度上下文进行前向传播
|
||||
# 在 CUDA 设备上使用 bfloat16/float16, CPU 设备上不使用混合精度
|
||||
with autocast_ctx:
|
||||
# 模型前向传播, 输入 input_ids 和 labels, 返回 loss 和 aux_loss
|
||||
res = model(input_ids, labels=labels)
|
||||
# 总损失 = 主损失 (交叉熵) + 辅助损失 (MoE 负载均衡损失)
|
||||
loss = res.loss + res.aux_loss
|
||||
# 梯度累积: 将损失除以累积步数, 模拟更大 batch size 的训练效果
|
||||
loss = loss / args.accumulation_steps
|
||||
|
||||
# 使用 GradScaler 进行反向传播, 支持混合精度训练的梯度缩放
|
||||
scaler.scale(loss).backward()
|
||||
|
||||
# 每 accumulation_steps 步执行一次梯度更新
|
||||
if (step + 1) % args.accumulation_steps == 0:
|
||||
# 反缩放梯度, 将缩放后的梯度还原为真实梯度值
|
||||
scaler.unscale_(optimizer)
|
||||
# 梯度裁剪, 防止梯度爆炸, 限制梯度的 L2 范数不超过 grad_clip
|
||||
torch.nn.utils.clip_grad_norm_(model.parameters(), args.grad_clip)
|
||||
|
||||
# 执行优化器更新步骤
|
||||
scaler.step(optimizer)
|
||||
# 更新 GradScaler 的缩放因子, 根据梯度是否溢出动态调整
|
||||
scaler.update()
|
||||
|
||||
# 清空梯度, set_to_none=True 可以节省内存
|
||||
optimizer.zero_grad(set_to_none=True)
|
||||
|
||||
# 每 log_interval 步打印一次日志, 或在该 epoch 最后一步打印
|
||||
if step % args.log_interval == 0 or step == iters - 1:
|
||||
# 计算已花费的时间
|
||||
spend_time = time.time() - start_time
|
||||
# 还原真实的损失值 (之前被除以 accumulation_steps)
|
||||
current_loss = loss.item() * args.accumulation_steps
|
||||
# 获取 MoE 辅助损失, 如果不存在则为 0
|
||||
current_aux_loss = res.aux_loss.item() if res.aux_loss is not None else 0.0
|
||||
# 计算主损失 (总损失减去辅助损失)
|
||||
current_logits_loss = current_loss - current_aux_loss
|
||||
# 获取当前学习率
|
||||
current_lr = optimizer.param_groups[-1]['lr']
|
||||
# 计算 ETA (预计剩余时间), 单位为分钟
|
||||
eta_min = spend_time / (step + 1) * iters // 60 - spend_time // 60
|
||||
# 打印训练日志
|
||||
Logger(f'Epoch:[{epoch + 1}/{args.epochs}]({step}/{iters}), loss: {current_loss:.4f}, logits_loss: {current_logits_loss:.4f}, aux_loss: {current_aux_loss:.4f}, lr: {current_lr:.8f}, epoch_time: {eta_min:.1f}min')
|
||||
# 如果启用了 wandb, 记录指标
|
||||
if wandb: wandb.log({"loss": current_loss, "logits_loss": current_logits_loss, "aux_loss": current_aux_loss, "learning_rate": current_lr, "epoch_time": eta_min})
|
||||
|
||||
# 每 save_interval 步保存一次模型, 或在 epoch 最后一步保存, 仅在主进程中执行
|
||||
if (step % args.save_interval == 0 or step == iters - 1) and is_main_process():
|
||||
# 切换到评估模式, 确保保存时模型处于稳定状态
|
||||
model.eval()
|
||||
# 根据是否使用 MoE 构建模型权重文件名后缀
|
||||
moe_suffix = '_moe' if lm_config.use_moe else ''
|
||||
# 构建权重保存路径
|
||||
ckp = f'{args.save_dir}/{args.save_weight}_{lm_config.hidden_size}{moe_suffix}.pth'
|
||||
# 如果模型被 DDP 包装, 获取原始模型
|
||||
raw_model = model.module if isinstance(model, DistributedDataParallel) else model
|
||||
# 如果模型被 torch.compile 编译, 获取原始模型
|
||||
raw_model = getattr(raw_model, '_orig_mod', raw_model)
|
||||
# 获取模型状态字典
|
||||
state_dict = raw_model.state_dict()
|
||||
# 将权重转换为半精度 (FP16) 并保存到 CPU, 然后保存到磁盘
|
||||
torch.save({k: v.half().cpu() for k, v in state_dict.items()}, ckp)
|
||||
lm_checkpoint(lm_config, weight=args.save_weight, model=model, optimizer=optimizer,
|
||||
epoch=epoch, step=step, wandb=wandb, save_dir='../checkpoints', scaler=scaler)
|
||||
# 保存完整的检查点 (包含优化器状态等), 用于恢复训练
|
||||
lm_checkpoint(lm_config, weight=args.save_weight, model=model, optimizer=optimizer, epoch=epoch, step=step, wandb=wandb, save_dir='../checkpoints', scaler=scaler)
|
||||
# 切换回训练模式
|
||||
model.train()
|
||||
# 删除状态字典以释放内存
|
||||
del state_dict
|
||||
|
||||
# 删除张量以释放 GPU 内存
|
||||
del input_ids, labels, res, loss
|
||||
|
||||
|
||||
@ -87,76 +143,112 @@ if __name__ == "__main__":
|
||||
parser.add_argument("--save_interval", type=int, default=1000, help="模型保存间隔")
|
||||
parser.add_argument('--hidden_size', default=512, type=int, help="隐藏层维度")
|
||||
parser.add_argument('--num_hidden_layers', default=8, type=int, help="隐藏层数量")
|
||||
parser.add_argument('--max_seq_len', default=340, type=int, help="训练的最大截断长度(中文1token≈1.5~1.7字符)")
|
||||
parser.add_argument('--use_moe', default=0, type=int, choices=[0, 1], help="是否使用MoE架构(0=否,1=是)")
|
||||
parser.add_argument('--max_seq_len', default=340, type=int, help="训练的最大截断长度 (中文 1 token ≈ 1.5~1.7 字符)")
|
||||
parser.add_argument('--use_moe', default=0, type=int, choices=[0, 1], help="是否使用MoE架构(0=否, 1=是)")
|
||||
parser.add_argument("--data_path", type=str, default="../dataset/sft_mini_512.jsonl", help="训练数据路径")
|
||||
parser.add_argument('--from_weight', default='pretrain', type=str, help="基于哪个权重训练,为none则不基于任何权重训练")
|
||||
parser.add_argument('--from_resume', default=0, type=int, choices=[0, 1], help="是否自动检测&续训(0=否,1=是)")
|
||||
parser.add_argument('--from_weight', default='pretrain', type=str, help="基于哪个权重训练, 为 none 则不基于任何权重训练")
|
||||
parser.add_argument('--from_resume', default=0, type=int, choices=[0, 1], help="是否自动检测&续训(0=否, 1=是)")
|
||||
parser.add_argument("--use_wandb", action="store_true", help="是否使用wandb")
|
||||
parser.add_argument("--wandb_project", type=str, default="MiniMind-Full-SFT", help="wandb项目名")
|
||||
parser.add_argument("--use_compile", default=0, type=int, choices=[0, 1], help="是否使用torch.compile加速(0=否,1=是)")
|
||||
parser.add_argument("--use_compile", default=0, type=int, choices=[0, 1], help="是否使用torch.compile加速(0=否, 1=是)")
|
||||
args = parser.parse_args()
|
||||
|
||||
# ========== 1. 初始化环境和随机种子 ==========
|
||||
# ========== 1.初始化环境和随机种子 ==========
|
||||
# 初始化分布式训练模式, 如果是分布式训练则返回 local_rank, 否则返回 0
|
||||
local_rank = init_distributed_mode()
|
||||
# 如果分布式环境已初始化, 根据 local_rank 设置当前进程使用的 GPU 设备
|
||||
if dist.is_initialized(): args.device = f"cuda:{local_rank}"
|
||||
# 设置随机种子, 确保实验可复现, 不同 rank 使用不同种子避免数据同步问题
|
||||
setup_seed(42 + (dist.get_rank() if dist.is_initialized() else 0))
|
||||
|
||||
# ========== 2. 配置目录、模型参数、检查ckp ==========
|
||||
|
||||
# ========== 2.配置目录、模型参数、检查ckp ==========
|
||||
# 创建模型保存目录, 如果目录已存在则不报错
|
||||
os.makedirs(args.save_dir, exist_ok=True)
|
||||
# 创建模型配置对象, 包含隐藏层维度、层数、是否使用 MoE 等参数
|
||||
lm_config = MiniMindConfig(hidden_size=args.hidden_size, num_hidden_layers=args.num_hidden_layers, use_moe=bool(args.use_moe))
|
||||
# 如果启用自动恢复训练, 尝试从检查点加载数据, 否则返回 None
|
||||
ckp_data = lm_checkpoint(lm_config, weight=args.save_weight, save_dir='../checkpoints') if args.from_resume==1 else None
|
||||
|
||||
# ========== 3. 设置混合精度 ==========
|
||||
|
||||
# ========== 3.设置混合精度 ==========
|
||||
# 根据设备类型确定是 cuda 还是 cpu
|
||||
device_type = "cuda" if "cuda" in args.device else "cpu"
|
||||
# 根据配置选择混合精度的数据类型: bfloat16 或 float16
|
||||
dtype = torch.bfloat16 if args.dtype == "bfloat16" else torch.float16
|
||||
autocast_ctx = nullcontext() if device_type == "cpu" else torch.cuda.amp.autocast(dtype=dtype)
|
||||
|
||||
# ========== 4. 配wandb ==========
|
||||
# https://docs.pytorch.ac.cn/docs/stable/amp
|
||||
# 新版废弃: autocast_ctx = nullcontext() if device_type == "cpu" else torch.cuda.amp.autocast(dtype=dtype)
|
||||
autocast_ctx = nullcontext() if device_type == "cpu" else torch.amp.autocast(device_type=device_type, dtype=dtype)
|
||||
|
||||
# ========== 4.配置 wandb ==========
|
||||
wandb = None
|
||||
# 仅在主进程中初始化 Weights & Biases 日志记录
|
||||
if args.use_wandb and is_main_process():
|
||||
import swanlab as wandb
|
||||
# 如果有检查点数据, 获取之前的 wandb 运行 ID 用于恢复日志记录
|
||||
wandb_id = ckp_data.get('wandb_id') if ckp_data else None
|
||||
# 如果有 wandb_id, 设置为必须恢复模式; 否则创建新运行
|
||||
resume = 'must' if wandb_id else None
|
||||
# 构建运行名称, 包含训练配置参数
|
||||
wandb_run_name = f"MiniMind-Full-SFT-Epoch-{args.epochs}-BatchSize-{args.batch_size}-LearningRate-{args.learning_rate}"
|
||||
# 初始化 wandb 运行
|
||||
wandb.init(project=args.wandb_project, name=wandb_run_name, id=wandb_id, resume=resume)
|
||||
|
||||
# ========== 5. 定义模型、数据、优化器 ==========
|
||||
|
||||
# ========== 5.定义模型、数据、优化器 ==========
|
||||
# 初始化模型和分词器, 可以选择从预训练权重加载
|
||||
model, tokenizer = init_model(lm_config, args.from_weight, device=args.device)
|
||||
# 如果启用 torch.compile, 编译模型以加速训练
|
||||
if args.use_compile == 1:
|
||||
model = torch.compile(model)
|
||||
Logger('torch.compile enabled')
|
||||
# 创建 SFT 数据集, 用于监督微调
|
||||
train_ds = SFTDataset(args.data_path, tokenizer, max_length=args.max_seq_len)
|
||||
# 如果启用分布式训练, 创建 DistributedSampler; 否则使用 None (后续会用普通索引)
|
||||
train_sampler = DistributedSampler(train_ds) if dist.is_initialized() else None
|
||||
scaler = torch.cuda.amp.GradScaler(enabled=(args.dtype == 'float16'))
|
||||
# 创建梯度缩放器, 仅在 float16 模式下启用 (用于混合精度训练)
|
||||
# https://docs.pytorch.ac.cn/docs/stable/amp.html#gradient-scaling
|
||||
# 新版废弃: scaler = torch.cuda.amp.GradScaler(enabled=(args.dtype == 'float16'))
|
||||
scaler = torch.amp.GradScaler("cuda", enabled=(args.dtype == 'float16'))
|
||||
# 创建 AdamW 优化器
|
||||
optimizer = optim.AdamW(model.parameters(), lr=args.learning_rate)
|
||||
|
||||
# ========== 6. 从ckp恢复状态 ==========
|
||||
|
||||
# ========== 6.从 ckp 恢复状态 ==========
|
||||
# 初始化起始轮次和起始步数
|
||||
start_epoch, start_step = 0, 0
|
||||
# 如果有检查点数据, 恢复模型、优化器、scaler 的状态, 以及训练进度
|
||||
if ckp_data:
|
||||
model.load_state_dict(ckp_data['model'])
|
||||
optimizer.load_state_dict(ckp_data['optimizer'])
|
||||
scaler.load_state_dict(ckp_data['scaler'])
|
||||
start_epoch = ckp_data['epoch']
|
||||
start_step = ckp_data.get('step', 0)
|
||||
|
||||
# ========== 7. DDP包模型 ==========
|
||||
|
||||
# ========== 7.DDP 包装模型 ==========
|
||||
# 如果启用分布式训练, 使用 DistributedDataParallel 包装模型
|
||||
if dist.is_initialized():
|
||||
# 设置 DDP 忽略同步的参数: RoPE 的位置编码不需要梯度同步
|
||||
model._ddp_params_and_buffers_to_ignore = {"freqs_cos", "freqs_sin"}
|
||||
# 使用 DDP 包装模型, 指定当前进程使用的 GPU 设备
|
||||
model = DistributedDataParallel(model, device_ids=[local_rank])
|
||||
|
||||
# ========== 8. 开始训练 ==========
|
||||
|
||||
# ========== 8.开始训练 ==========
|
||||
# 从 start_epoch 开始训练, 直到 args.epochs
|
||||
for epoch in range(start_epoch, args.epochs):
|
||||
# 设置 DistributedSampler 的 epoch, 确保每个 epoch 的数据打乱方式不同
|
||||
train_sampler and train_sampler.set_epoch(epoch)
|
||||
# 为每个 epoch 设置不同的随机种子, 确保数据打乱顺序不同
|
||||
setup_seed(42 + epoch); indices = torch.randperm(len(train_ds)).tolist()
|
||||
# 如果是恢复训练的第一个 epoch, 需要跳过已经训练过的 step; 其他 epoch 从 0 开始
|
||||
skip = start_step if (epoch == start_epoch and start_step > 0) else 0
|
||||
# 创建 SkipBatchSampler, 用于跳过指定数量的批次
|
||||
batch_sampler = SkipBatchSampler(train_sampler or indices, args.batch_size, skip)
|
||||
# 创建 DataLoader, 使用自定义的 batch_sampler
|
||||
loader = DataLoader(train_ds, batch_sampler=batch_sampler, num_workers=args.num_workers, pin_memory=True)
|
||||
if skip > 0:
|
||||
Logger(f'Epoch [{epoch + 1}/{args.epochs}]: 跳过前{start_step}个step,从step {start_step + 1}开始')
|
||||
# 如果需要跳过批次, 打印提示信息并调用 train_epoch 时传入正确的总步数和起始步数
|
||||
if skip > 0:
|
||||
Logger(f'Epoch [{epoch + 1}/{args.epochs}]: 跳过前{start_step}个step, 从step {start_step + 1}开始')
|
||||
train_epoch(epoch, loader, len(loader) + skip, start_step, wandb)
|
||||
else:
|
||||
train_epoch(epoch, loader, len(loader), 0, wandb)
|
||||
|
||||
|
||||
# ========== 9. 清理分布进程 ==========
|
||||
# 训练结束后, 如果启用了分布式训练, 销毁进程组
|
||||
if dist.is_initialized(): dist.destroy_process_group()
|
||||
@ -1,13 +1,14 @@
|
||||
import os
|
||||
import sys
|
||||
|
||||
# 设置包名
|
||||
__package__ = "trainer"
|
||||
# 将项目根目录添加到系统路径, 确保能够导入同级目录的模块
|
||||
sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), '..')))
|
||||
|
||||
import argparse
|
||||
import re
|
||||
import gc
|
||||
import warnings
|
||||
import torch
|
||||
import torch.distributed as dist
|
||||
from transformers import AutoTokenizer
|
||||
@ -21,17 +22,43 @@ from model.model_minimind import MiniMindConfig, MiniMindForCausalLM
|
||||
from dataset.lm_dataset import RLAIFDataset
|
||||
from trainer.trainer_utils import Logger, is_main_process, lm_checkpoint, init_distributed_mode, setup_seed, SkipBatchSampler, init_model
|
||||
|
||||
# 忽略警告信息
|
||||
import warnings
|
||||
warnings.filterwarnings('ignore')
|
||||
|
||||
|
||||
def calculate_rewards(prompts, responses, reward_model, reward_tokenizer):
|
||||
"""整合所有奖励函数计算总奖励"""
|
||||
"""
|
||||
整合所有奖励函数计算总奖励
|
||||
|
||||
根据配置使用 reasoning 模型奖励函数和/或 reward model 评分来计算总奖励
|
||||
支持推理模型的格式检查和普通模型的 reward model 评分
|
||||
|
||||
Args:
|
||||
- prompts: 输入 prompt 列表, list[str], 长度为 batch_size
|
||||
- responses: 模型生成的回复列表, list[str], 长度为 batch_size * num_generations
|
||||
- reward_model: reward model 实例, 用于计算语义相似度分数
|
||||
- reward_tokenizer: reward model 的分词器
|
||||
|
||||
Returns:
|
||||
- torch.Tensor: 每个生成样本的奖励分数, shape [B*num_gen]
|
||||
"""
|
||||
def reasoning_model_reward(rewards):
|
||||
"""
|
||||
如果使用 reasoning 模型, 应用 reasoning 奖励函数
|
||||
|
||||
if args.reasoning == 1:
|
||||
rewards = reasoning_model_reward(rewards)
|
||||
"""
|
||||
# 定义推理模型的格式匹配模式: <think> 块后跟 <answer> 块
|
||||
pattern = r"^<think>\n.*?\n</think>\n<answer>\n.*?\n</answer>$"
|
||||
# 备选模式: 允许 <think> 和 <answer> 之间有一个空行
|
||||
pattern2 = r"^<think>\n.*?\n</think>\n\n<answer>\n.*?\n</answer>$"
|
||||
# 检查每个回复是否匹配格式
|
||||
matches_pattern = [re.match(pattern, response, re.S) for response in responses]
|
||||
matches_pattern2 = [re.match(pattern2, response, re.S) for response in responses]
|
||||
|
||||
# 计算格式奖励: 格式正确得 0.5 分, 否则 0 分
|
||||
format_rewards = []
|
||||
for match_pattern, match_pattern2 in zip(matches_pattern, matches_pattern2):
|
||||
if match_pattern or match_pattern2:
|
||||
@ -41,6 +68,7 @@ def calculate_rewards(prompts, responses, reward_model, reward_tokenizer):
|
||||
rewards += torch.tensor(format_rewards, device=args.device)
|
||||
|
||||
def mark_num(text):
|
||||
# 统计标签出现次数, 每个正确出现的标签奖励 0.25 分
|
||||
reward = 0
|
||||
if text.count("<think>") == 1: reward += 0.25
|
||||
if text.count("</think>") == 1: reward += 0.25
|
||||
@ -48,33 +76,42 @@ def calculate_rewards(prompts, responses, reward_model, reward_tokenizer):
|
||||
if text.count("</answer>") == 1: reward += 0.25
|
||||
return reward
|
||||
|
||||
# 计算标签数量奖励
|
||||
mark_rewards = [mark_num(response) for response in responses]
|
||||
rewards += torch.tensor(mark_rewards, device=args.device)
|
||||
return rewards
|
||||
|
||||
# 初始化奖励张量为 0
|
||||
rewards = torch.zeros(len(responses), device=args.device)
|
||||
# 如果使用 reasoning 模型, 应用 reasoning 奖励函数
|
||||
if args.reasoning == 1:
|
||||
rewards = reasoning_model_reward(rewards)
|
||||
|
||||
# 使用 reward model 计算语义奖励
|
||||
with torch.no_grad():
|
||||
reward_model_scores = []
|
||||
batch_size = len(prompts)
|
||||
scale = 3.0
|
||||
|
||||
# 遍历每个 prompt 和每个生成的回复
|
||||
for i in range(batch_size):
|
||||
for j in range(args.num_generations):
|
||||
response_idx = i * args.num_generations + j
|
||||
response = responses[response_idx]
|
||||
prompt = prompts[i]
|
||||
|
||||
# 解析 prompt 中的对话历史
|
||||
pattern = r"<\|im_start\|>(system|user|assistant)\s+(.*?)<\|im_end\|>"
|
||||
matches = re.findall(pattern, prompt, re.DOTALL)
|
||||
messages = [{"role": role, "content": content.strip()} for role, content in matches]
|
||||
|
||||
# 构建完整的对话上下文并计算 reward model 分数
|
||||
tmp_chat = messages + [{"role": "assistant", "content": response}]
|
||||
score = reward_model.get_score(reward_tokenizer, tmp_chat)
|
||||
# 将分数裁剪到 [-scale, scale] 范围
|
||||
score = max(min(score, scale), -scale)
|
||||
|
||||
# 对于 reasoning 模型, 额外计算 answer 部分的分数
|
||||
if args.reasoning == 1:
|
||||
answer_match = re.search(r'<answer>(.*?)</answer>', response, re.DOTALL)
|
||||
if answer_match:
|
||||
@ -82,10 +119,12 @@ def calculate_rewards(prompts, responses, reward_model, reward_tokenizer):
|
||||
tmp_chat = messages + [{"role": "assistant", "content": answer_content}]
|
||||
answer_score = reward_model.get_score(reward_tokenizer, tmp_chat)
|
||||
answer_score = max(min(answer_score, scale), -scale)
|
||||
# 综合分数: 40% 完整回复 + 60% answer 部分
|
||||
score = score * 0.4 + answer_score * 0.6
|
||||
|
||||
reward_model_scores.append(score)
|
||||
|
||||
# 将 reward model 分数添加到总奖励
|
||||
reward_model_scores = torch.tensor(reward_model_scores, device=args.device)
|
||||
rewards += reward_model_scores
|
||||
|
||||
@ -93,24 +132,58 @@ def calculate_rewards(prompts, responses, reward_model, reward_tokenizer):
|
||||
|
||||
|
||||
def grpo_train_epoch(epoch, loader, iters, ref_model, reward_model, reward_tokenizer, start_step=0, wandb=None):
|
||||
"""
|
||||
GRPO (Group Relative Policy Optimization) 单 epoch 训练函数
|
||||
|
||||
实现 GRPO 算法的核心训练循环, 包括:
|
||||
- 生成多组回复样本
|
||||
- 计算 reward 和 advantage
|
||||
- 计算 policy loss 和 KL divergence
|
||||
- 更新模型参数
|
||||
|
||||
Args:
|
||||
- epoch: 当前训练轮次
|
||||
- loader: 数据加载器
|
||||
- iters: 总迭代次数
|
||||
- ref_model: reference model (冻结参数), 用于计算 KL divergence
|
||||
- reward_model: reward model, 用于计算奖励分数
|
||||
- reward_tokenizer: reward model 的分词器
|
||||
- start_step: 从第几步开始 (用于断点续训)
|
||||
- wandb: 日志记录对象
|
||||
"""
|
||||
for step, batch in enumerate(loader, start=start_step + 1):
|
||||
# 获取 prompt 列表并编码
|
||||
prompts = batch['prompt'] # list[str], length B
|
||||
prompt_inputs = tokenizer(prompts, return_tensors="pt", padding=True, return_token_type_ids=False,
|
||||
padding_side="left", add_special_tokens=False).to(args.device) # input_ids: [B, P], attention_mask: [B, P]
|
||||
prompt_inputs = tokenizer(
|
||||
prompts,
|
||||
return_tensors="pt",
|
||||
padding=True,
|
||||
return_token_type_ids=False,
|
||||
padding_side="left",
|
||||
add_special_tokens=False).to(args.device) # input_ids: [B, P], attention_mask: [B, P]
|
||||
# 截断过长的 prompt
|
||||
if args.max_seq_len:
|
||||
prompt_inputs["input_ids"] = prompt_inputs["input_ids"][:, -args.max_seq_len:]
|
||||
prompt_inputs["attention_mask"] = prompt_inputs["attention_mask"][:, -args.max_seq_len:]
|
||||
|
||||
# 使用当前 model 生成多个回复
|
||||
with torch.no_grad():
|
||||
# DDP 模型需要使用 .module 访问 generate 方法
|
||||
model_for_gen = model.module if isinstance(model, DistributedDataParallel) else model
|
||||
outputs = model_for_gen.generate(
|
||||
**prompt_inputs, max_new_tokens=args.max_gen_len, do_sample=True, temperature=0.8,
|
||||
num_return_sequences=args.num_generations, pad_token_id=tokenizer.pad_token_id) # [B*num_gen, P+R]
|
||||
**prompt_inputs,
|
||||
max_new_tokens=args.max_gen_len,
|
||||
do_sample=True, temperature=0.8,
|
||||
num_return_sequences=args.num_generations,
|
||||
pad_token_id=tokenizer.pad_token_id) # [B*num_gen, P+R]
|
||||
|
||||
# 提取生成的回复部分 (去掉 prompt)
|
||||
completion_ids = outputs[:, prompt_inputs["input_ids"].size(1):] # [B*num_gen, R]
|
||||
|
||||
def get_per_token_logps(mdl, input_ids, n_keep):
|
||||
"""
|
||||
获取每个 token 的 log probability
|
||||
"""
|
||||
input_ids = input_ids.detach().clone() if input_ids.is_inference() else input_ids
|
||||
logits = mdl(input_ids, logits_to_keep=n_keep + 1).logits[:, :-1, :]
|
||||
per_token_logps = []
|
||||
@ -119,35 +192,44 @@ def grpo_train_epoch(epoch, loader, iters, ref_model, reward_model, reward_token
|
||||
per_token_logps.append(torch.gather(logits_row.log_softmax(dim=-1), 1, ids_row.unsqueeze(1)).squeeze(1))
|
||||
return torch.stack(per_token_logps)
|
||||
|
||||
# 计算 policy model 的 per-token log probabilities
|
||||
with autocast_ctx:
|
||||
per_token_logps = get_per_token_logps(model, outputs, completion_ids.size(1)) # [B*num_gen, R]
|
||||
# 如果使用 MoE, 计算辅助损失
|
||||
res = model(outputs) if lm_config.use_moe else None
|
||||
aux_loss = res.aux_loss if res is not None else torch.tensor(0.0, device=args.device)
|
||||
|
||||
# 计算 reference model 的 per-token log probabilities (无梯度)
|
||||
with torch.no_grad():
|
||||
ref_per_token_logps = get_per_token_logps(ref_model, outputs, completion_ids.size(1)) # [B*num_gen, R]
|
||||
|
||||
# 解码生成的回复并计算奖励
|
||||
completions = tokenizer.batch_decode(completion_ids, skip_special_tokens=True)
|
||||
rewards = calculate_rewards(prompts, completions, reward_model, reward_tokenizer).to(args.device) # [B*num_gen]
|
||||
|
||||
grouped_rewards = rewards.view(-1, args.num_generations) # [B, num_gen]
|
||||
mean_r = grouped_rewards.mean(dim=1).repeat_interleave(args.num_generations) # [B*num_gen]
|
||||
std_r = grouped_rewards.std(dim=1).repeat_interleave(args.num_generations) # [B*num_gen]
|
||||
# 按 group 分组计算 advantage (GRPO 核心)
|
||||
grouped_rewards = rewards.view(-1, args.num_generations) # [B, num_gen]
|
||||
mean_r = grouped_rewards.mean(dim=1).repeat_interleave(args.num_generations) # [B*num_gen]
|
||||
std_r = grouped_rewards.std(dim=1).repeat_interleave(args.num_generations) # [B*num_gen]
|
||||
advantages = torch.clamp((rewards - mean_r) / (std_r + 1e-4), -10, 10)
|
||||
advantages = (advantages - advantages.mean()) / (advantages.std() + 1e-8) # [B*num_gen]
|
||||
advantages = (advantages - advantages.mean()) / (advantages.std() + 1e-8) # [B*num_gen]
|
||||
|
||||
is_eos = completion_ids == tokenizer.eos_token_id # [B*num_gen, R]
|
||||
# 构建 completion mask (处理 EOS 提前结束的情况)
|
||||
is_eos = completion_ids == tokenizer.eos_token_id # [B*num_gen, R]
|
||||
eos_idx = torch.full((is_eos.size(0),), is_eos.size(1), dtype=torch.long, device=args.device)
|
||||
eos_idx[is_eos.any(dim=1)] = is_eos.int().argmax(dim=1)[is_eos.any(dim=1)]
|
||||
completion_mask = (torch.arange(is_eos.size(1), device=args.device).expand(is_eos.size(0), -1) <= eos_idx.unsqueeze(1)).int() # [B*num_gen, R]
|
||||
|
||||
# 计算 KL divergence
|
||||
kl_div = ref_per_token_logps - per_token_logps
|
||||
per_token_kl = torch.exp(kl_div) - kl_div - 1 # [B*num_gen, R]
|
||||
per_token_kl = torch.exp(kl_div) - kl_div - 1 # [B*num_gen, R]
|
||||
# 计算 GRPO loss: advantage-weighted loss - beta * KL penalty
|
||||
per_token_loss = -(torch.exp(per_token_logps - per_token_logps.detach()) * advantages.unsqueeze(1) - args.beta * per_token_kl) # [B*num_gen, R]
|
||||
policy_loss = ((per_token_loss * completion_mask).sum(dim=1) / completion_mask.sum(dim=1)).mean()
|
||||
loss = (policy_loss + aux_loss) / args.accumulation_steps # scalar
|
||||
loss = (policy_loss + aux_loss) / args.accumulation_steps # scalar
|
||||
loss.backward()
|
||||
|
||||
# 梯度累积更新
|
||||
if (step + 1) % args.accumulation_steps == 0:
|
||||
if args.grad_clip > 0:
|
||||
torch.nn.utils.clip_grad_norm_(model.parameters(), args.grad_clip)
|
||||
@ -155,6 +237,7 @@ def grpo_train_epoch(epoch, loader, iters, ref_model, reward_model, reward_token
|
||||
scheduler.step()
|
||||
optimizer.zero_grad()
|
||||
|
||||
# 日志记录
|
||||
if step % args.log_interval == 0 or step == iters:
|
||||
policy_loss_val = loss.item() * args.accumulation_steps
|
||||
current_aux_loss = aux_loss.item()
|
||||
@ -176,19 +259,24 @@ def grpo_train_epoch(epoch, loader, iters, ref_model, reward_model, reward_token
|
||||
"learning_rate": current_lr
|
||||
})
|
||||
|
||||
# 模型保存检查点
|
||||
if (step % args.save_interval == 0 or step == iters - 1) and is_main_process():
|
||||
model.eval()
|
||||
moe_suffix = '_moe' if lm_config.use_moe else ''
|
||||
ckp = f'{args.save_dir}/{args.save_weight}_{lm_config.hidden_size}{moe_suffix}.pth'
|
||||
# 解包 DDP 或 torch.compile 包装的模型
|
||||
raw_model = model.module if isinstance(model, DistributedDataParallel) else model
|
||||
raw_model = getattr(raw_model, '_orig_mod', raw_model)
|
||||
state_dict = raw_model.state_dict()
|
||||
# 保存半精度权重到 CPU
|
||||
torch.save({k: v.half().cpu() for k, v in state_dict.items()}, ckp)
|
||||
# 保存完整的检查点 (用于断点续训)
|
||||
lm_checkpoint(lm_config, weight=args.save_weight, model=model, optimizer=optimizer,
|
||||
epoch=epoch, step=step, wandb=wandb, save_dir='../checkpoints', scheduler=scheduler)
|
||||
model.train()
|
||||
del state_dict
|
||||
|
||||
# 清理中间变量释放内存
|
||||
del prompt_inputs, outputs, completion_ids, per_token_logps, ref_per_token_logps
|
||||
del completions, rewards, grouped_rewards, mean_r, std_r, advantages, completion_mask
|
||||
|
||||
@ -209,37 +297,47 @@ if __name__ == "__main__":
|
||||
parser.add_argument("--save_interval", type=int, default=10, help="模型保存间隔")
|
||||
parser.add_argument('--hidden_size', default=512, type=int, help="隐藏层维度")
|
||||
parser.add_argument('--num_hidden_layers', default=8, type=int, help="隐藏层数量")
|
||||
parser.add_argument('--use_moe', default=0, type=int, choices=[0, 1], help="是否使用MoE架构(0=否,1=是)")
|
||||
parser.add_argument('--use_moe', default=0, type=int, choices=[0, 1], help="是否使用MoE架构 (0=否, 1=是)")
|
||||
parser.add_argument('--max_seq_len', default=66, type=int, help="Prompt最大长度")
|
||||
parser.add_argument("--max_gen_len", type=int, default=1536, help="生成的最大长度")
|
||||
parser.add_argument("--data_path", type=str, default="../dataset/rlaif-mini.jsonl", help="RLAIF数据路径")
|
||||
parser.add_argument("--num_generations", type=int, default=8, help="每个prompt生成的样本数")
|
||||
parser.add_argument("--beta", type=float, default=0.02, help="KL惩罚系数")
|
||||
parser.add_argument("--reasoning", type=int, default=1, choices=[0, 1], help='推理模型类型(0=普通模型,1=推理模型)')
|
||||
parser.add_argument("--reasoning", type=int, default=1, choices=[0, 1], help='推理模型类型 (0=普通模型, 1=推理模型)')
|
||||
parser.add_argument("--reward_model_path", type=str, default="../../internlm2-1_8b-reward", help="Reward模型路径")
|
||||
parser.add_argument('--from_resume', default=0, type=int, choices=[0, 1], help="是否自动检测&续训(0=否,1=是)")
|
||||
parser.add_argument('--from_resume', default=0, type=int, choices=[0, 1], help="是否自动检测&续训 (0=否, 1=是)")
|
||||
parser.add_argument("--use_wandb", action="store_true", help="是否使用wandb")
|
||||
parser.add_argument("--wandb_project", type=str, default="MiniMind-GRPO", help="wandb项目名")
|
||||
parser.add_argument("--use_compile", default=0, type=int, choices=[0, 1], help="是否使用torch.compile加速(0=否,1=是)")
|
||||
parser.add_argument("--use_compile", default=0, type=int, choices=[0, 1], help="是否使用torch.compile加速 (0=否, 1=是)")
|
||||
args = parser.parse_args()
|
||||
|
||||
# ========== 1. 初始化环境和随机种子 ==========
|
||||
# ========== 1.初始化环境和随机种子 ==========
|
||||
local_rank = init_distributed_mode()
|
||||
# 如果启用了分布式训练, 根据 local_rank 设置设备
|
||||
if dist.is_initialized(): args.device = f"cuda:{local_rank}"
|
||||
# 设置随机种子, 分布式训练时根据 rank 偏移确保不同进程使用不同种子
|
||||
setup_seed(42 + (dist.get_rank() if dist.is_initialized() else 0))
|
||||
|
||||
# ========== 2. 配置目录、模型参数、检查ckp ==========
|
||||
# ========== 2.配置目录、模型参数、检查 ckp ==========
|
||||
os.makedirs(args.save_dir, exist_ok=True)
|
||||
lm_config = MiniMindConfig(hidden_size=args.hidden_size, num_hidden_layers=args.num_hidden_layers,
|
||||
max_seq_len=args.max_seq_len + args.max_gen_len, use_moe=bool(args.use_moe))
|
||||
# 创建模型配置, max_seq_len 包含 prompt 和生成的最大长度
|
||||
lm_config = MiniMindConfig(
|
||||
hidden_size=args.hidden_size,
|
||||
num_hidden_layers=args.num_hidden_layers,
|
||||
max_seq_len=args.max_seq_len + args.max_gen_len,
|
||||
use_moe=bool(args.use_moe)
|
||||
)
|
||||
# 如果启用断点续训, 尝试加载检查点
|
||||
ckp_data = lm_checkpoint(lm_config, weight=args.save_weight, save_dir='../checkpoints') if args.from_resume==1 else None
|
||||
|
||||
# ========== 3. 设置混合精度 ==========
|
||||
# ========== 3.设置混合精度 ==========
|
||||
device_type = "cuda" if "cuda" in args.device else "cpu"
|
||||
dtype = torch.bfloat16 if args.dtype == "bfloat16" else torch.float16
|
||||
autocast_ctx = nullcontext() if device_type == "cpu" else torch.cuda.amp.autocast(dtype=dtype)
|
||||
|
||||
# ========== 4. 配wandb ==========
|
||||
# https://docs.pytorch.ac.cn/docs/stable/amp
|
||||
# 新版废弃: autocast_ctx = nullcontext() if device_type == "cpu" else torch.cuda.amp.autocast(dtype=dtype)
|
||||
autocast_ctx = nullcontext() if device_type == "cpu" else torch.amp.autocast(device_type=device_type, dtype=dtype)
|
||||
|
||||
# ========== 4.配置 wandb ==========
|
||||
wandb = None
|
||||
if args.use_wandb and is_main_process():
|
||||
import swanlab as wandb
|
||||
@ -248,49 +346,56 @@ if __name__ == "__main__":
|
||||
wandb_run_name = f"MiniMind-GRPO-Epoch-{args.epochs}-BS-{args.batch_size}-LR-{args.learning_rate}"
|
||||
wandb.init(project=args.wandb_project, name=wandb_run_name, id=wandb_id, resume=resume)
|
||||
|
||||
# ========== 5. 初始化模型和数据 ==========
|
||||
# ========== 5.初始化模型和数据 ==========
|
||||
# 根据是否为 reasoning 模型选择基础权重
|
||||
base_weight = "reason" if args.reasoning == 1 else "full_sft"
|
||||
# Policy模型
|
||||
# Policy 模型 (将被训练的模型)
|
||||
model, tokenizer = init_model(lm_config, base_weight, device=args.device)
|
||||
# 如果启用 torch.compile 加速
|
||||
if args.use_compile == 1:
|
||||
model = torch.compile(model)
|
||||
Logger('torch.compile enabled')
|
||||
# Reference模型
|
||||
# Reference 模型 (冻结参数, 用于计算 KL divergence)
|
||||
ref_model, _ = init_model(lm_config, base_weight, device=args.device)
|
||||
ref_model = ref_model.eval().requires_grad_(False)
|
||||
# Reward模型
|
||||
reward_model = AutoModel.from_pretrained(
|
||||
args.reward_model_path, torch_dtype=torch.float16, trust_remote_code=True
|
||||
)
|
||||
# Reward 模型 (用于计算奖励分数, 冻结参数)
|
||||
reward_model = AutoModel.from_pretrained(args.reward_model_path, torch_dtype=torch.float16, trust_remote_code=True)
|
||||
reward_model = reward_model.to(args.device).eval().requires_grad_(False)
|
||||
# Reward 模型分词器
|
||||
reward_tokenizer = AutoTokenizer.from_pretrained(args.reward_model_path, trust_remote_code=True)
|
||||
# 数据和优化器
|
||||
# 加载数据集和创建优化器
|
||||
train_ds = RLAIFDataset(args.data_path, tokenizer, max_length=lm_config.max_seq_len)
|
||||
train_sampler = DistributedSampler(train_ds) if dist.is_initialized() else None
|
||||
optimizer = optim.AdamW(model.parameters(), lr=args.learning_rate)
|
||||
# 计算训练迭代次数和学习率调度器
|
||||
loader_for_count = DataLoader(train_ds, batch_size=args.batch_size, sampler=train_sampler)
|
||||
iters = len(loader_for_count)
|
||||
total_optimizer_steps = (iters // args.accumulation_steps) * args.epochs
|
||||
scheduler = CosineAnnealingLR(optimizer, T_max=total_optimizer_steps, eta_min=args.learning_rate / 10)
|
||||
|
||||
# ========== 6. 从ckp恢复状态 ==========
|
||||
# ========== 6.从 ckp 恢复状态 ==========
|
||||
start_epoch, start_step = 0, 0
|
||||
if ckp_data:
|
||||
# 加载模型权重、优化器状态和学习率调度器状态
|
||||
model.load_state_dict(ckp_data['model'])
|
||||
optimizer.load_state_dict(ckp_data['optimizer'])
|
||||
scheduler.load_state_dict(ckp_data['scheduler'])
|
||||
start_epoch = ckp_data['epoch']
|
||||
start_step = ckp_data.get('step', 0)
|
||||
|
||||
# ========== 7. DDP包模型 ==========
|
||||
# ========== 7.DDP 包装模型 ==========
|
||||
if dist.is_initialized():
|
||||
# 忽略位置编码参数不参与 DDP 同步
|
||||
model._ddp_params_and_buffers_to_ignore = {"freqs_cos", "freqs_sin"}
|
||||
model = DistributedDataParallel(model, device_ids=[local_rank])
|
||||
|
||||
# ========== 8. 开始训练 ==========
|
||||
# ========== 8.开始训练 ==========
|
||||
for epoch in range(start_epoch, args.epochs):
|
||||
# 分布式采样器设置 epoch 以确保不同进程数据不同
|
||||
train_sampler and train_sampler.set_epoch(epoch)
|
||||
# 每轮使用不同的随机种子打乱数据
|
||||
setup_seed(42 + epoch); indices = torch.randperm(len(train_ds)).tolist()
|
||||
# 如果是续训的第一个 epoch, 跳过已训练的 step
|
||||
skip = start_step if (epoch == start_epoch and start_step > 0) else 0
|
||||
batch_sampler = SkipBatchSampler(train_sampler or indices, args.batch_size, skip)
|
||||
loader = DataLoader(train_ds, batch_sampler=batch_sampler, num_workers=args.num_workers, pin_memory=True)
|
||||
@ -300,5 +405,5 @@ if __name__ == "__main__":
|
||||
else:
|
||||
grpo_train_epoch(epoch, loader, len(loader), ref_model, reward_model, reward_tokenizer, 0, wandb)
|
||||
|
||||
# ========== 9. 清理分布进程 ==========
|
||||
# ========== 9.清理分布式进程 ==========
|
||||
if dist.is_initialized(): dist.destroy_process_group()
|
||||
@ -1,12 +1,13 @@
|
||||
import os
|
||||
import sys
|
||||
|
||||
# 设置包名
|
||||
__package__ = "trainer"
|
||||
# 将项目根目录添加到系统路径, 确保能够导入同级目录的模块
|
||||
sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), '..')))
|
||||
|
||||
import argparse
|
||||
import time
|
||||
import warnings
|
||||
import torch
|
||||
import torch.distributed as dist
|
||||
from contextlib import nullcontext
|
||||
@ -18,50 +19,95 @@ from dataset.lm_dataset import SFTDataset
|
||||
from model.model_lora import save_lora, apply_lora
|
||||
from trainer.trainer_utils import get_lr, Logger, is_main_process, lm_checkpoint, init_distributed_mode, setup_seed, init_model, SkipBatchSampler
|
||||
|
||||
# 忽略所有警告信息, 保持输出整洁
|
||||
import warnings
|
||||
warnings.filterwarnings('ignore')
|
||||
|
||||
|
||||
def train_epoch(epoch, loader, iters, lora_params, start_step=0, wandb=None):
|
||||
"""
|
||||
训练单个 epoch 的函数
|
||||
|
||||
执行 LoRA 微调的一个完整训练周期, 包括前向传播、反向传播、梯度更新和日志记录
|
||||
|
||||
Args:
|
||||
- epoch: 当前训练轮次
|
||||
- loader: 数据加载器, 提供训练数据批次
|
||||
- iters: 当前 epoch 的总迭代次数
|
||||
- lora_params: LoRA 可训练参数列表, 用于梯度裁剪
|
||||
- start_step: 起始步数 (用于断点续训时跳过已训练部分)
|
||||
- wandb: Weights & Biases 日志对象, 可选
|
||||
"""
|
||||
# 记录训练开始时间, 用于计算 ETA (预计完成时间)
|
||||
start_time = time.time()
|
||||
# 遍历数据加载器中的每个批次, 从 start_step + 1 开始计数
|
||||
for step, (input_ids, labels) in enumerate(loader, start=start_step + 1):
|
||||
# 将输入数据移动到指定设备 (GPU 或 CPU)
|
||||
input_ids = input_ids.to(args.device)
|
||||
labels = labels.to(args.device)
|
||||
# 使用余弦退火策略计算当前步的学习率
|
||||
lr = get_lr(epoch * iters + step, args.epochs * iters, args.learning_rate)
|
||||
# 更新优化器中的学习率
|
||||
for param_group in optimizer.param_groups:
|
||||
param_group['lr'] = lr
|
||||
|
||||
# 使用自动混合精度上下文进行前向传播 (仅在 GPU 时启用)
|
||||
with autocast_ctx:
|
||||
# 模型前向传播, 计算损失 (主损失 + 辅助损失, 如 MoE 负载均衡损失)
|
||||
res = model(input_ids, labels=labels)
|
||||
loss = res.loss + res.aux_loss
|
||||
# 梯度累积: 将损失除以累积步数, 模拟大批量训练效果
|
||||
loss = loss / args.accumulation_steps
|
||||
|
||||
# 使用梯度缩放器进行反向传播 (防止 FP16 梯度下溢)
|
||||
scaler.scale(loss).backward()
|
||||
|
||||
# 每 accumulation_steps 步执行一次参数更新
|
||||
if (step + 1) % args.accumulation_steps == 0:
|
||||
# 反缩放梯度, 准备进行梯度裁剪
|
||||
scaler.unscale_(optimizer)
|
||||
# 对 LoRA 参数进行梯度裁剪, 防止梯度爆炸
|
||||
torch.nn.utils.clip_grad_norm_(lora_params, args.grad_clip)
|
||||
# 执行优化器步骤, 更新参数
|
||||
scaler.step(optimizer)
|
||||
# 更新梯度缩放器的缩放因子
|
||||
scaler.update()
|
||||
# 清空梯度, 释放内存
|
||||
optimizer.zero_grad(set_to_none=True)
|
||||
|
||||
# 每隔 log_interval 步或最后一步打印训练日志
|
||||
if step % args.log_interval == 0 or step == iters - 1:
|
||||
# 计算已花费时间
|
||||
spend_time = time.time() - start_time
|
||||
# 恢复原始损失值 (乘以累积步数)
|
||||
current_loss = loss.item() * args.accumulation_steps
|
||||
# 获取辅助损失值 (MoE 负载均衡损失)
|
||||
current_aux_loss = res.aux_loss.item() if res.aux_loss is not None else 0.0
|
||||
# 计算主损失 (总损失减去辅助损失)
|
||||
current_logits_loss = current_loss - current_aux_loss
|
||||
# 获取当前学习率
|
||||
current_lr = optimizer.param_groups[-1]['lr']
|
||||
# 计算预计剩余时间 (ETA)
|
||||
eta_min = spend_time / (step + 1) * iters // 60 - spend_time // 60
|
||||
# 在主进程打印训练进度日志
|
||||
Logger(f'Epoch:[{epoch + 1}/{args.epochs}]({step}/{iters}), loss: {current_loss:.4f}, logits_loss: {current_logits_loss:.4f}, aux_loss: {current_aux_loss:.4f}, lr: {current_lr:.8f}, epoch_time: {eta_min:.1f}min')
|
||||
# 如果使用 wandb, 记录训练指标
|
||||
if wandb: wandb.log({"loss": current_loss, "logits_loss": current_logits_loss, "aux_loss": current_aux_loss, "learning_rate": current_lr, "epoch_time": eta_min})
|
||||
|
||||
# 每隔 save_interval 步或最后一步保存模型检查点 (仅主进程执行)
|
||||
if (step % args.save_interval == 0 or step == iters - 1) and is_main_process():
|
||||
# 切换到评估模式, 准备保存模型
|
||||
model.eval()
|
||||
# 构建 LoRA 权重保存路径
|
||||
lora_save_path = f'{args.save_dir}/{args.lora_name}_{lm_config.hidden_size}.pth'
|
||||
# LoRA只保存LoRA权重
|
||||
save_lora(model, lora_save_path)
|
||||
# 保存完整的训练检查点 (包含模型、优化器、学习率调度器状态)
|
||||
lm_checkpoint(lm_config, weight=args.lora_name, model=model, optimizer=optimizer, scaler=scaler, epoch=epoch, step=step, wandb=wandb, save_dir='../checkpoints')
|
||||
# 切换回训练模式
|
||||
model.train()
|
||||
|
||||
# 显式删除张量, 释放 GPU 内存
|
||||
del input_ids, labels, res, loss
|
||||
|
||||
|
||||
@ -78,98 +124,146 @@ if __name__ == "__main__":
|
||||
parser.add_argument("--accumulation_steps", type=int, default=1, help="梯度累积步数")
|
||||
parser.add_argument("--grad_clip", type=float, default=1.0, help="梯度裁剪阈值")
|
||||
parser.add_argument("--log_interval", type=int, default=10, help="日志打印间隔")
|
||||
parser.add_argument("--save_interval", type=int, default=1000, help="模型保存间隔")
|
||||
parser.add_argument("--save_interval", type=int, default=1, help="模型保存间隔")
|
||||
parser.add_argument('--hidden_size', default=512, type=int, help="隐藏层维度")
|
||||
parser.add_argument('--num_hidden_layers', default=8, type=int, help="隐藏层数量")
|
||||
parser.add_argument('--max_seq_len', default=340, type=int, help="训练的最大截断长度(中文1token≈1.5~1.7字符)")
|
||||
parser.add_argument('--use_moe', default=0, type=int, choices=[0, 1], help="是否使用MoE架构(0=否,1=是)")
|
||||
parser.add_argument('--max_seq_len', default=340, type=int, help="训练的最大截断长度 (中文 1 token ≈ 1.5~1.7 字符)")
|
||||
parser.add_argument('--use_moe', default=0, type=int, choices=[0, 1], help="是否使用MoE架构 (0=否, 1=是)")
|
||||
parser.add_argument("--data_path", type=str, default="../dataset/lora_identity.jsonl", help="LoRA训练数据路径")
|
||||
parser.add_argument('--from_weight', default='full_sft', type=str, help="基于哪个权重训练,默认full_sft")
|
||||
parser.add_argument('--from_resume', default=0, type=int, choices=[0, 1], help="是否自动检测&续训(0=否,1=是)")
|
||||
parser.add_argument('--from_weight', default='full_sft', type=str, help="基于哪个权重训练, 默认 full_sft")
|
||||
parser.add_argument('--from_resume', default=0, type=int, choices=[0, 1], help="是否自动检测&续训 (0=否, 1=是)")
|
||||
parser.add_argument("--use_wandb", action="store_true", help="是否使用wandb")
|
||||
parser.add_argument("--wandb_project", type=str, default="MiniMind-LoRA", help="wandb项目名")
|
||||
parser.add_argument("--use_compile", default=0, type=int, choices=[0, 1], help="是否使用torch.compile加速(0=否,1=是)")
|
||||
parser.add_argument("--use_compile", default=0, type=int, choices=[0, 1], help="是否使用torch.compile加速 (0=否, 1=是)")
|
||||
args = parser.parse_args()
|
||||
|
||||
# ========== 1. 初始化环境和随机种子 ==========
|
||||
# ========== 1.初始化环境和随机种子 ==========
|
||||
# 初始化分布式训练模式, 返回本地 GPU 编号
|
||||
local_rank = init_distributed_mode()
|
||||
# 如果启用了分布式训练, 将设备设置为当前进程对应的 GPU
|
||||
if dist.is_initialized(): args.device = f"cuda:{local_rank}"
|
||||
setup_seed(42 + (dist.get_rank() if dist.is_initialized() else 0))
|
||||
# 设置随机种子, 不同进程使用不同的种子以避免数据重复 (基础种子 42 + 进程 rank)
|
||||
|
||||
# ========== 2. 配置目录、模型参数、检查ckp ==========
|
||||
# ========== 2.配置目录、模型参数、检查ckp ==========
|
||||
# 确保模型保存目录存在
|
||||
os.makedirs(args.save_dir, exist_ok=True)
|
||||
# 创建 MiniMind 模型配置对象, 设置隐藏层维度、层数和是否使用 MoE
|
||||
lm_config = MiniMindConfig(hidden_size=args.hidden_size, num_hidden_layers=args.num_hidden_layers, use_moe=bool(args.use_moe))
|
||||
# 如果启用了断点续训, 尝试从检查点加载训练状态
|
||||
ckp_data = lm_checkpoint(lm_config, weight=args.lora_name, save_dir='../checkpoints') if args.from_resume==1 else None
|
||||
|
||||
# ========== 3. 设置混合精度 ==========
|
||||
# ========== 3.设置混合精度 ==========
|
||||
device_type = "cuda" if "cuda" in args.device else "cpu"
|
||||
# 根据参数选择数据类型: bfloat16 或 float16
|
||||
dtype = torch.bfloat16 if args.dtype == "bfloat16" else torch.float16
|
||||
autocast_ctx = nullcontext() if device_type == "cpu" else torch.cuda.amp.autocast(dtype=dtype)
|
||||
# https://docs.pytorch.ac.cn/docs/stable/amp
|
||||
# 新版废弃: autocast_ctx = nullcontext() if device_type == "cpu" else torch.cuda.amp.autocast(dtype=dtype)
|
||||
autocast_ctx = nullcontext() if device_type == "cpu" else torch.amp.autocast(device_type=device_type, dtype=dtype)
|
||||
|
||||
# ========== 4. 配wandb ==========
|
||||
# ========== 4.配置 wandb ==========
|
||||
# 初始化 wandb 对象为 None, 如果不启用则保持为 None
|
||||
wandb = None
|
||||
# 仅在主进程且启用 wandb 时初始化
|
||||
if args.use_wandb and is_main_process():
|
||||
# 导入 swanlab 作为 wandb 的替代 (可能是国内用户适配)
|
||||
import swanlab as wandb
|
||||
# 从检查点恢复 wandb 运行 ID (如果存在)
|
||||
wandb_id = ckp_data.get('wandb_id') if ckp_data else None
|
||||
# 如果存在 wandb_id, 则必须恢复原有运行; 否则创建新运行
|
||||
resume = 'must' if wandb_id else None
|
||||
# 构建 wandb 运行名称, 包含关键训练参数
|
||||
wandb_run_name = f"MiniMind-LoRA-{args.lora_name}-Epoch-{args.epochs}-BatchSize-{args.batch_size}-LR-{args.learning_rate}"
|
||||
# 初始化 wandb 运行
|
||||
wandb.init(project=args.wandb_project, name=wandb_run_name, id=wandb_id, resume=resume)
|
||||
|
||||
# ========== 5. 定义模型、应用LoRA、冻结非LoRA参数 ==========
|
||||
# ========== 5.定义模型、应用LoRA、冻结非LoRA参数 ==========
|
||||
# 初始化 MiniMind 模型和分词器, 从指定权重加载预训练参数
|
||||
model, tokenizer = init_model(lm_config, args.from_weight, device=args.device)
|
||||
# 如果启用 torch.compile, 编译模型以加速训练 (PyTorch 2.0+ 特性)
|
||||
if args.use_compile == 1:
|
||||
model = torch.compile(model)
|
||||
Logger('torch.compile enabled')
|
||||
# 应用 LoRA (Low-Rank Adaptation) 到模型, 添加可训练的低秩矩阵
|
||||
apply_lora(model)
|
||||
|
||||
# 统计参数
|
||||
# 统计模型参数量
|
||||
total_params = sum(p.numel() for p in model.parameters())
|
||||
# 统计 LoRA 相关参数量 (参数名中包含 'lora' 的参数)
|
||||
lora_params_count = sum(p.numel() for name, p in model.named_parameters() if 'lora' in name)
|
||||
# 打印模型参数量统计信息
|
||||
Logger(f"LLM 总参数量: {total_params / 1e6:.3f} M")
|
||||
Logger(f"LoRA 参数量: {lora_params_count / 1e6:.3f} M")
|
||||
Logger(f"LoRA 参数占比: {lora_params_count / total_params * 100:.2f}%")
|
||||
|
||||
# 冻结非LoRA参数,收集LoRA参数
|
||||
# 冻结非 LoRA 参数, 收集 LoRA 参数
|
||||
lora_params = []
|
||||
for name, param in model.named_parameters():
|
||||
if 'lora' in name:
|
||||
# 启用 LoRA 参数的梯度计算
|
||||
param.requires_grad = True
|
||||
# 将 LoRA 参数添加到可训练参数列表
|
||||
lora_params.append(param)
|
||||
else:
|
||||
# 冻结非 LoRA 参数 (原始模型参数不更新)
|
||||
param.requires_grad = False
|
||||
|
||||
# ========== 6. 定义数据和优化器 ==========
|
||||
# ========== 6.定义数据和优化器 ==========
|
||||
# 创建指令微调数据集 (SFTDataset), 加载 JSONL 格式的训练数据
|
||||
train_ds = SFTDataset(args.data_path, tokenizer, max_length=args.max_seq_len)
|
||||
# 如果启用分布式训练, 使用 DistributedSampler 确保数据不重复; 否则为 None
|
||||
train_sampler = DistributedSampler(train_ds) if dist.is_initialized() else None
|
||||
scaler = torch.cuda.amp.GradScaler(enabled=(args.dtype == 'float16'))
|
||||
# 创建梯度缩放器, 仅在 float16 模式下启用 (用于混合精度训练, 防止梯度下溢)
|
||||
# https://docs.pytorch.ac.cn/docs/stable/amp.html#gradient-scaling
|
||||
# 新版废弃: scaler = torch.cuda.amp.GradScaler(enabled=(args.dtype == 'float16'))
|
||||
scaler = torch.amp.GradScaler("cuda", enabled=(args.dtype == 'float16'))
|
||||
# 创建 AdamW 优化器, 仅优化 LoRA 参数
|
||||
optimizer = optim.AdamW(lora_params, lr=args.learning_rate)
|
||||
|
||||
# ========== 7. 从ckp恢复状态 ==========
|
||||
# ========== 7.从检查点恢复状态 ==========
|
||||
# 初始化起始 epoch 和 step 为 0 (从头训练)
|
||||
start_epoch, start_step = 0, 0
|
||||
# 如果存在检查点数据, 恢复训练状态
|
||||
if ckp_data:
|
||||
# 加载模型权重 (strict=False 允许部分加载, 适应 LoRA 场景)
|
||||
model.load_state_dict(ckp_data['model'], strict=False)
|
||||
# 恢复优化器状态 (包括动量等)
|
||||
optimizer.load_state_dict(ckp_data['optimizer'])
|
||||
# 恢复梯度缩放器状态
|
||||
scaler.load_state_dict(ckp_data['scaler'])
|
||||
# 恢复训练轮次
|
||||
start_epoch = ckp_data['epoch']
|
||||
# 恢复训练步数 (默认为 0)
|
||||
start_step = ckp_data.get('step', 0)
|
||||
|
||||
# ========== 8. DDP包模型 ==========
|
||||
# ========== 8.DDP包装模型 ==========
|
||||
# 如果启用分布式训练, 使用 DDP 包装模型
|
||||
if dist.is_initialized():
|
||||
# 设置 DDP 忽略同步的参数 (位置编码相关的缓存, 不需要梯度同步)
|
||||
model._ddp_params_and_buffers_to_ignore = {"freqs_cos", "freqs_sin"}
|
||||
# 使用 DDP 包装模型, 指定当前进程使用的 GPU 设备
|
||||
model = DistributedDataParallel(model, device_ids=[local_rank])
|
||||
|
||||
# ========== 9. 开始训练 ==========
|
||||
# ========== 9.开始训练 ==========
|
||||
# 遍历从 start_epoch 到总 epoch 数的每个轮次
|
||||
for epoch in range(start_epoch, args.epochs):
|
||||
# 如果存在分布式采样器, 设置当前 epoch (确保数据打乱顺序在不同 epoch 不同)
|
||||
train_sampler and train_sampler.set_epoch(epoch)
|
||||
# 设置当前 epoch 的随机种子, 并生成随机打乱的索引列表
|
||||
setup_seed(42 + epoch); indices = torch.randperm(len(train_ds)).tolist()
|
||||
# 如果是恢复训练的第一个 epoch 且需要跳过部分 step, 设置 skip; 否则为 0
|
||||
skip = start_step if (epoch == start_epoch and start_step > 0) else 0
|
||||
# 创建 SkipBatchSampler, 支持跳过指定数量的批次 (断点续训用)
|
||||
batch_sampler = SkipBatchSampler(train_sampler or indices, args.batch_size, skip)
|
||||
# 创建 DataLoader, 使用自定义的 batch_sampler
|
||||
loader = DataLoader(train_ds, batch_sampler=batch_sampler, num_workers=args.num_workers, pin_memory=True)
|
||||
# 如果需要跳过部分 step, 打印提示信息并调用 train_epoch
|
||||
if skip > 0:
|
||||
Logger(f'Epoch [{epoch + 1}/{args.epochs}]: 跳过前{start_step}个step,从step {start_step + 1}开始')
|
||||
train_epoch(epoch, loader, len(loader) + skip, lora_params, start_step, wandb)
|
||||
else:
|
||||
# 正常训练, 从 step 0 开始
|
||||
train_epoch(epoch, loader, len(loader), lora_params, 0, wandb)
|
||||
|
||||
# ========== 10. 清理分布进程 ==========
|
||||
# ========== 10.清理分布式进程 ==========
|
||||
# 如果启用了分布式训练, 销毁进程组, 释放资源
|
||||
if dist.is_initialized(): dist.destroy_process_group()
|
||||
@ -1,12 +1,13 @@
|
||||
import os
|
||||
import sys
|
||||
|
||||
# 设置包名
|
||||
__package__ = "trainer"
|
||||
# 将项目根目录添加到系统路径, 确保能够导入同级目录的模块
|
||||
sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), '..')))
|
||||
|
||||
import argparse
|
||||
import re
|
||||
import warnings
|
||||
import torch
|
||||
import torch.distributed as dist
|
||||
import torch.nn.functional as F
|
||||
@ -22,35 +23,69 @@ from model.model_minimind import MiniMindConfig, MiniMindForCausalLM
|
||||
from dataset.lm_dataset import RLAIFDataset
|
||||
from trainer.trainer_utils import Logger, is_main_process, lm_checkpoint, init_distributed_mode, setup_seed, SkipBatchSampler, init_model
|
||||
|
||||
# 忽略警告信息, 避免输出干扰
|
||||
import warnings
|
||||
warnings.filterwarnings('ignore')
|
||||
|
||||
|
||||
# 自定义的Critic模型,继承自MiniMindLM
|
||||
class CriticModel(MiniMindForCausalLM):
|
||||
"""
|
||||
Critic 模型, 继承自 MiniMindForCausalLM
|
||||
用于估计状态价值, 为 PPO 算法提供优势函数计算基础
|
||||
"""
|
||||
def __init__(self, params):
|
||||
super().__init__(params)
|
||||
# 替换lm_head为输出单一价值的线性层
|
||||
# 添加价值头, 将隐藏层输出映射为单一标量价值
|
||||
self.value_head = nn.Linear(params.hidden_size, 1)
|
||||
|
||||
def forward(self, input_ids=None, attention_mask=None, **kwargs):
|
||||
# 使用基础模型获取隐藏状态
|
||||
# 使用基础模型获取隐藏状态输出
|
||||
outputs = self.model(input_ids=input_ids, attention_mask=attention_mask, **kwargs)
|
||||
# 对最后一层隐藏状态进行 LayerNorm 归一化
|
||||
hidden_states = self.model.norm(outputs[0])
|
||||
# 使用value_head获取价值估计
|
||||
# 通过 value_head 计算每个位置的价值估计, 并移除最后一个维度
|
||||
values = self.value_head(hidden_states).squeeze(-1)
|
||||
return values
|
||||
|
||||
|
||||
def calculate_rewards(prompts, responses, reward_model, reward_tokenizer):
|
||||
"""整合所有奖励函数计算总奖励"""
|
||||
"""
|
||||
整合所有奖励函数计算总奖励
|
||||
|
||||
Args:
|
||||
- prompts: 输入提示列表, list[str]
|
||||
- responses: 模型生成的响应列表, list[str]
|
||||
- reward_model: 奖励模型, 用于评估回答质量
|
||||
- reward_tokenizer: 奖励模型的分词器
|
||||
|
||||
Returns:
|
||||
- torch.Tensor: 每个样本的总奖励值, shape [B]
|
||||
"""
|
||||
def reasoning_model_reward(rewards):
|
||||
# 1. 格式奖励(仅针对训练推理模型时使用)
|
||||
"""
|
||||
推理模型的奖励计算
|
||||
|
||||
包含两部分奖励:
|
||||
1. 格式奖励: 检查回答是否符合 <think>...</think><answer>...</answer> 格式
|
||||
2. 标记奖励: 检查关键标记 (think, answer 标签) 的出现次数
|
||||
|
||||
Args:
|
||||
- rewards: 当前奖励张量, shape [B]
|
||||
|
||||
Returns:
|
||||
- torch.Tensor: 更新后的奖励张量
|
||||
"""
|
||||
# 定义正则表达式模式, 匹配标准的推理格式
|
||||
# 模式1: <think> 和 <answer> 之间只有一个换行
|
||||
pattern = r"^<think>\n.*?\n</think>\n<answer>\n.*?\n</answer>$"
|
||||
# 模式2: <think> 和 <answer> 之间有两个换行
|
||||
pattern2 = r"^<think>\n.*?\n</think>\n\n<answer>\n.*?\n</answer>$"
|
||||
|
||||
# 对每个响应进行正则匹配
|
||||
matches_pattern = [re.match(pattern, response, re.S) for response in responses]
|
||||
matches_pattern2 = [re.match(pattern2, response, re.S) for response in responses]
|
||||
|
||||
# 计算格式奖励: 匹配任一模式得 0.5 分, 否则 0 分
|
||||
format_rewards = []
|
||||
for match_pattern, match_pattern2 in zip(matches_pattern, matches_pattern2):
|
||||
if match_pattern:
|
||||
@ -61,8 +96,9 @@ def calculate_rewards(prompts, responses, reward_model, reward_tokenizer):
|
||||
format_rewards.append(0.0)
|
||||
rewards += torch.tensor(format_rewards, device=args.device)
|
||||
|
||||
# 2. 标记奖励(防止严格奖励稀疏,仅针对训练推理模型时使用)
|
||||
# 计算标记奖励, 防止严格格式奖励过于稀疏
|
||||
def mark_num(text):
|
||||
"""统计文本中关键标记的出现次数, 每类标记正确出现一次得 0.25 分"""
|
||||
reward = 0
|
||||
if text.count("<think>") == 1:
|
||||
reward += 0.25
|
||||
@ -78,38 +114,46 @@ def calculate_rewards(prompts, responses, reward_model, reward_tokenizer):
|
||||
rewards += torch.tensor(mark_rewards, device=args.device)
|
||||
return rewards
|
||||
|
||||
# 初始化奖励张量为零, shape [B]
|
||||
rewards = torch.zeros(len(responses), device=args.device)
|
||||
|
||||
# 格式奖励
|
||||
# 如果是推理模型, 添加格式和标记奖励
|
||||
if args.reasoning == 1:
|
||||
rewards = reasoning_model_reward(rewards)
|
||||
|
||||
# 使用reward model计算整个response的奖励
|
||||
# 使用奖励模型计算 response 的整体质量分数
|
||||
with torch.no_grad():
|
||||
reward_model_scores = []
|
||||
for prompt, response in zip(prompts, responses):
|
||||
# 解析 prompt 中的对话历史
|
||||
pattern = r"<\|im_start\|>(system|user|assistant)\s+(.*?)<\|im_end\|>"
|
||||
matches = re.findall(pattern, prompt, re.DOTALL)
|
||||
messages = [{"role": role, "content": content.strip()} for role, content in matches]
|
||||
|
||||
# 构建完整的对话, 包含当前 response
|
||||
tmp_chat = messages + [{"role": "assistant", "content": response}]
|
||||
# 调用奖励模型获取评分
|
||||
score = reward_model.get_score(reward_tokenizer, tmp_chat)
|
||||
|
||||
# 将分数裁剪到 [-scale, scale] 范围内, 避免极端值
|
||||
scale = 3.0
|
||||
score = max(min(score, scale), -scale)
|
||||
|
||||
# 当args.reasoning=1时,额外计算<answer>内容的奖励
|
||||
# 如果是推理模型, 额外计算 <answer> 内容部分的奖励
|
||||
if args.reasoning == 1:
|
||||
answer_match = re.search(r'<answer>(.*?)</answer>', response, re.DOTALL)
|
||||
if answer_match:
|
||||
# 提取 answer 标签内的内容
|
||||
answer_content = answer_match.group(1).strip()
|
||||
# 对answer内容单独计算reward
|
||||
# 对 answer 内容单独计算奖励
|
||||
tmp_chat = messages + [{"role": "assistant", "content": answer_content}]
|
||||
answer_score = reward_model.get_score(reward_tokenizer, tmp_chat)
|
||||
answer_score = max(min(answer_score, scale), -scale)
|
||||
# 综合评分: 整体回答 40% + answer 内容 60%
|
||||
score = score * 0.4 + answer_score * 0.6
|
||||
reward_model_scores.append(score)
|
||||
|
||||
# 将奖励模型分数转换为张量并累加到总奖励
|
||||
reward_model_scores = torch.tensor(reward_model_scores, device=args.device)
|
||||
rewards += reward_model_scores
|
||||
|
||||
@ -117,81 +161,143 @@ def calculate_rewards(prompts, responses, reward_model, reward_tokenizer):
|
||||
|
||||
|
||||
def ppo_train_epoch(epoch, loader, iters, old_actor_model, ref_model, actor_scheduler, critic_scheduler, reward_model, reward_tokenizer, start_step=0, wandb=None):
|
||||
"""
|
||||
PPO 单 epoch 训练函数
|
||||
|
||||
Args:
|
||||
- epoch: 当前训练轮次
|
||||
- loader: 数据加载器
|
||||
- iters: 总迭代次数
|
||||
- old_actor_model: 旧策略模型, 用于计算重要性采样比率
|
||||
- ref_model: 参考模型, 用于计算 KL 散度惩罚
|
||||
- actor_scheduler: Actor 学习率调度器
|
||||
- critic_scheduler: Critic 学习率调度器
|
||||
- reward_model: 奖励模型
|
||||
- reward_tokenizer: 奖励模型分词器
|
||||
- start_step: 起始步数 (用于断点续训)
|
||||
- wandb: 日志记录工具
|
||||
|
||||
Note:
|
||||
使用全局变量: actor_model, critic_model, tokenizer, lm_config, args, actor_optimizer, critic_optimizer, autocast_ctx
|
||||
"""
|
||||
# 设置模型为训练模式
|
||||
actor_model.train()
|
||||
critic_model.train()
|
||||
|
||||
for step, batch in enumerate(loader, start=start_step + 1):
|
||||
# 获取批次中的 prompt 列表
|
||||
prompts = batch["prompt"] # list[str], length B
|
||||
enc = tokenizer(prompts, return_tensors="pt", padding=True, truncation=True,
|
||||
# 对 prompt 进行编码, 左填充以支持生成
|
||||
enc = tokenizer(prompts, return_tensors="pt", padding=True, truncation=True,
|
||||
max_length=args.max_seq_len, padding_side="left").to(args.device) # input_ids: [B, P], attention_mask: [B, P]
|
||||
prompt_length = enc.input_ids.shape[1]
|
||||
|
||||
# 使用当前策略生成响应 (不计算梯度)
|
||||
with torch.no_grad():
|
||||
# DDP 模型需要使用 .module 访问 generate 方法
|
||||
# DDP 包装模型需要通过 .module 访问 generate 方法
|
||||
model_for_gen = actor_model.module if isinstance(actor_model, DistributedDataParallel) else actor_model
|
||||
gen_out = model_for_gen.generate(
|
||||
input_ids=enc.input_ids, attention_mask=enc.attention_mask,
|
||||
max_new_tokens=args.max_gen_len, do_sample=True, temperature=0.8,
|
||||
pad_token_id=tokenizer.pad_token_id, eos_token_id=tokenizer.eos_token_id) # [B, P+R]
|
||||
|
||||
# 解码生成的响应文本 (去除 prompt 部分)
|
||||
responses_text = [tokenizer.decode(gen_out[i, prompt_length:], skip_special_tokens=True) for i in range(len(prompts))]
|
||||
# 计算每个样本的奖励值
|
||||
rewards = calculate_rewards(prompts, responses_text, reward_model, reward_tokenizer) # [B]
|
||||
|
||||
# 创建完整序列的注意力掩码 (非填充位置为 1)
|
||||
full_mask = (gen_out != tokenizer.pad_token_id).long() # [B, P+R]
|
||||
# 使用 Critic 模型评估生成序列的价值
|
||||
values_seq = critic_model(input_ids=gen_out, attention_mask=full_mask) # [B, P+R]
|
||||
# 找到每个序列最后一个非填充位置的索引
|
||||
last_indices = (full_mask * torch.arange(full_mask.size(1), device=gen_out.device)).argmax(dim=1)
|
||||
# 提取每个序列最终位置的价值估计
|
||||
values = values_seq[torch.arange(values_seq.size(0), device=values_seq.device), last_indices] # [B]
|
||||
# 计算优势函数: 奖励 - 价值 (使用 detach 阻止梯度流向 Critic)
|
||||
advantages = rewards - values.detach() # [B]
|
||||
|
||||
# 使用混合精度计算 Actor 模型输出
|
||||
with autocast_ctx:
|
||||
res = actor_model(input_ids=gen_out, attention_mask=full_mask)
|
||||
logits = res.logits # [B, P+R, V]
|
||||
# 如果是 MoE 模型, 获取辅助损失; 否则为 0
|
||||
aux_loss = res.aux_loss if lm_config.use_moe else torch.tensor(0.0, device=args.device)
|
||||
|
||||
|
||||
# 准备标签 (向右移动一位, 用于计算每个位置的 log prob)
|
||||
labels = gen_out[:, 1:].clone() # [B, P+R-1]
|
||||
# 计算每个 token 的对数概率
|
||||
logp_tokens = F.log_softmax(logits[:, :-1], dim=-1).gather(2, labels.unsqueeze(-1)).squeeze(-1) # [B, P+R-1]
|
||||
seq_len = gen_out.size(1) - 1
|
||||
# 创建响应部分的掩码 (只计算生成部分, 不计算 prompt)
|
||||
resp_mask = torch.arange(seq_len, device=gen_out.device).unsqueeze(0) >= prompt_length - 1
|
||||
# 最终掩码: 响应部分且非填充位置
|
||||
final_mask = resp_mask & (~labels.eq(tokenizer.pad_token_id)) # [B, P+R-1]
|
||||
# 计算 Actor 策略下整个响应的总对数概率
|
||||
actor_logp = (logp_tokens * final_mask).sum(dim=1) # [B]
|
||||
|
||||
# 使用旧策略和参考模型计算对数概率 (不计算梯度)
|
||||
with torch.no_grad():
|
||||
# 旧策略模型的输出
|
||||
old_logits = old_actor_model(input_ids=gen_out, attention_mask=full_mask).logits # [B, P+R, V]
|
||||
old_logp_tokens = F.log_softmax(old_logits[:, :-1], dim=-1).gather(2, labels.unsqueeze(-1)).squeeze(-1) # [B, P+R-1]
|
||||
old_logp = (old_logp_tokens * final_mask).sum(dim=1) # [B]
|
||||
|
||||
|
||||
# 参考模型的输出 (用于 KL 散度计算)
|
||||
ref_logits = ref_model(input_ids=gen_out, attention_mask=full_mask).logits # [B, P+R, V]
|
||||
ref_logp_tokens = F.log_softmax(ref_logits[:, :-1], dim=-1).gather(2, labels.unsqueeze(-1)).squeeze(-1) # [B, P+R-1]
|
||||
ref_logp = (ref_logp_tokens * final_mask).sum(dim=1) # [B]
|
||||
|
||||
# 计算与旧策略的 KL 散度
|
||||
kl = (actor_logp - old_logp).mean() # scalar
|
||||
# 计算与参考模型的 KL 散度 (作为惩罚项)
|
||||
kl_ref = (actor_logp - ref_logp).mean() # scalar
|
||||
# 计算重要性采样比率
|
||||
ratio = torch.exp(actor_logp - old_logp) # [B]
|
||||
# PPO 裁剪目标的第一项
|
||||
surr1 = ratio * advantages # [B]
|
||||
# PPO 裁剪目标的第二项 (裁剪比率防止过大更新)
|
||||
surr2 = torch.clamp(ratio, 1.0 - args.clip_epsilon, 1.0 + args.clip_epsilon) * advantages # [B]
|
||||
# 策略损失: 取裁剪后的最小值, 取负号后求平均 (因为我们要最大化目标)
|
||||
policy_loss = -torch.min(surr1, surr2).mean() # scalar
|
||||
# 价值损失: Critic 预测值与奖励的均方误差
|
||||
value_loss = F.mse_loss(values, rewards) # scalar
|
||||
# 总损失: 策略损失 + 价值系数 * 价值损失 + KL 系数 * KL 惩罚 + 辅助损失
|
||||
loss = (policy_loss + args.vf_coef * value_loss + args.kl_coef * kl_ref + aux_loss) / args.accumulation_steps # scalar
|
||||
# 反向传播计算梯度
|
||||
loss.backward()
|
||||
|
||||
# 梯度累积: 达到指定步数后执行参数更新
|
||||
if (step + 1) % args.accumulation_steps == 0:
|
||||
# 对 Actor 和 Critic 的梯度进行裁剪, 防止梯度爆炸
|
||||
clip_grad_norm_(actor_model.parameters(), args.grad_clip)
|
||||
clip_grad_norm_(critic_model.parameters(), args.grad_clip)
|
||||
# 执行优化器参数更新
|
||||
actor_optimizer.step()
|
||||
critic_optimizer.step()
|
||||
# 更新学习率
|
||||
actor_scheduler.step()
|
||||
critic_scheduler.step()
|
||||
# 清空梯度
|
||||
actor_optimizer.zero_grad()
|
||||
critic_optimizer.zero_grad()
|
||||
|
||||
# 只在主进程打印日志和记录指标
|
||||
if is_main_process():
|
||||
# 提取响应部分的 token IDs
|
||||
response_ids = gen_out[:, enc.input_ids.shape[1]:]
|
||||
# 检测每个响应是否包含 EOS 标记
|
||||
is_eos = (response_ids == tokenizer.eos_token_id)
|
||||
# 找到第一个 EOS 的位置
|
||||
eos_indices = torch.argmax(is_eos.int(), dim=1)
|
||||
# 检查每个响应是否实际包含 EOS
|
||||
has_eos = is_eos.any(dim=1)
|
||||
# 计算实际响应长度 (如果有 EOS 则为 EOS 位置+1, 否则为完整长度)
|
||||
lengths = torch.where(has_eos, eos_indices + 1, torch.tensor(response_ids.shape[1], device=is_eos.device))
|
||||
# 计算平均响应长度
|
||||
avg_len = lengths.float().mean()
|
||||
|
||||
# 提取各项指标数值
|
||||
actor_loss_val = policy_loss.item()
|
||||
critic_loss_val = value_loss.item()
|
||||
current_aux_loss = aux_loss.item()
|
||||
@ -202,6 +308,7 @@ def ppo_train_epoch(epoch, loader, iters, old_actor_model, ref_model, actor_sche
|
||||
actor_lr = actor_optimizer.param_groups[0]['lr']
|
||||
critic_lr = critic_optimizer.param_groups[0]['lr']
|
||||
|
||||
# 记录到 wandb
|
||||
if wandb is not None:
|
||||
wandb.log({
|
||||
"actor_loss": actor_loss_val,
|
||||
@ -214,35 +321,45 @@ def ppo_train_epoch(epoch, loader, iters, old_actor_model, ref_model, actor_sche
|
||||
"actor_lr": actor_lr,
|
||||
})
|
||||
|
||||
# 打印训练日志
|
||||
Logger(f"Epoch:[{epoch + 1}/{args.epochs}]({step}/{iters}), "
|
||||
f"Actor Loss: {actor_loss_val:.4f}, Critic Loss: {critic_loss_val:.4f}, Aux Loss: {current_aux_loss:.4f}, "
|
||||
f"Reward: {reward_val:.4f}, KL: {kl_val:.4f}, KL_ref: {kl_ref_val:.4f}, "
|
||||
f"Avg Response Len: {avg_len_val:.2f}, Actor LR: {actor_lr:.8f}, Critic LR: {critic_lr:.8f}")
|
||||
|
||||
# 按指定频率更新旧策略模型
|
||||
if (step + 1) % args.update_old_actor_freq == 0:
|
||||
# 解包 DDP 模型获取原始模型
|
||||
raw_actor = actor_model.module if isinstance(actor_model, DistributedDataParallel) else actor_model
|
||||
raw_actor = getattr(raw_actor, '_orig_mod', raw_actor)
|
||||
# 获取当前策略状态字典
|
||||
state_dict = raw_actor.state_dict()
|
||||
# 将参数 detached 后加载到旧策略模型
|
||||
old_actor_model.load_state_dict({k: v.detach().cpu() for k, v in state_dict.items()})
|
||||
old_actor_model.to(args.device)
|
||||
|
||||
# 按指定间隔保存检查点 (只在主进程执行)
|
||||
if (step % args.save_interval == 0 or step == iters - 1) and is_main_process():
|
||||
actor_model.eval()
|
||||
# 根据是否使用 MoE 构建文件名后缀
|
||||
moe_suffix = '_moe' if lm_config.use_moe else ''
|
||||
ckp = f'{args.save_dir}/{args.save_weight}_{lm_config.hidden_size}{moe_suffix}.pth'
|
||||
# 解包 DDP 模型获取原始 Actor 模型
|
||||
raw_actor = actor_model.module if isinstance(actor_model, DistributedDataParallel) else actor_model
|
||||
raw_actor = getattr(raw_actor, '_orig_mod', raw_actor)
|
||||
# 获取状态字典并转换为半精度后保存
|
||||
actor_state = raw_actor.state_dict()
|
||||
torch.save({k: v.half().cpu() for k, v in actor_state.items()}, ckp)
|
||||
|
||||
# 使用 lm_checkpoint 保存完整状态(包括 critic)
|
||||
lm_checkpoint(lm_config, weight=args.save_weight, model=actor_model, optimizer=actor_optimizer,
|
||||
|
||||
# 使用 lm_checkpoint 保存完整训练状态 (包括 Critic)
|
||||
lm_checkpoint(lm_config, weight=args.save_weight, model=actor_model, optimizer=actor_optimizer,
|
||||
epoch=epoch, step=step, wandb=wandb, save_dir='../checkpoints',
|
||||
scheduler=actor_scheduler, critic_model=critic_model,
|
||||
scheduler=actor_scheduler, critic_model=critic_model,
|
||||
critic_optimizer=critic_optimizer, critic_scheduler=critic_scheduler)
|
||||
actor_model.train()
|
||||
del actor_state
|
||||
|
||||
# 显式删除中间变量, 释放显存
|
||||
del enc, gen_out, responses_text, rewards, full_mask, values_seq, values, advantages
|
||||
del logits, labels, logp_tokens, final_mask, actor_logp, old_logits, old_logp, ref_logits, ref_logp
|
||||
del kl, kl_ref, ratio, surr1, surr2, policy_loss, value_loss, loss
|
||||
@ -265,117 +382,174 @@ if __name__ == "__main__":
|
||||
parser.add_argument("--save_interval", type=int, default=10, help="模型保存间隔")
|
||||
parser.add_argument('--hidden_size', default=512, type=int, help="隐藏层维度")
|
||||
parser.add_argument('--num_hidden_layers', default=8, type=int, help="隐藏层数量")
|
||||
parser.add_argument('--use_moe', default=0, type=int, choices=[0, 1], help="是否使用MoE架构(0=否,1=是)")
|
||||
parser.add_argument('--use_moe', default=0, type=int, choices=[0, 1], help="是否使用MoE架构 (0=否, 1=是)")
|
||||
parser.add_argument('--max_seq_len', default=66, type=int, help="Prompt最大长度")
|
||||
parser.add_argument("--max_gen_len", type=int, default=1536, help="生成的最大长度")
|
||||
parser.add_argument("--data_path", type=str, default="../dataset/rlaif-mini.jsonl", help="RLAIF数据路径")
|
||||
parser.add_argument("--clip_epsilon", type=float, default=0.1, help="PPO裁剪参数")
|
||||
parser.add_argument("--vf_coef", type=float, default=0.5, help="Value function系数")
|
||||
parser.add_argument("--kl_coef", type=float, default=0.02, help="KL散度惩罚系数")
|
||||
parser.add_argument("--reasoning", type=int, default=1, choices=[0, 1], help='推理模型类型(0=普通模型,1=推理模型)')
|
||||
parser.add_argument("--reasoning", type=int, default=1, choices=[0, 1], help='推理模型类型 (0=普通模型, 1=推理模型)')
|
||||
parser.add_argument("--update_old_actor_freq", type=int, default=4, help="更新old_actor_model的频率")
|
||||
parser.add_argument("--reward_model_path", type=str, default="../../internlm2-1_8b-reward", help="Reward模型路径")
|
||||
parser.add_argument('--from_resume', default=0, type=int, choices=[0, 1], help="是否自动检测&续训(0=否,1=是)")
|
||||
parser.add_argument('--from_resume', default=0, type=int, choices=[0, 1], help="是否自动检测&续训 (0=否, 1=是)")
|
||||
parser.add_argument("--use_wandb", action="store_true", help="是否使用wandb")
|
||||
parser.add_argument("--wandb_project", type=str, default="MiniMind-PPO", help="wandb项目名")
|
||||
parser.add_argument("--use_compile", default=0, type=int, choices=[0, 1], help="是否使用torch.compile加速(0=否,1=是)")
|
||||
parser.add_argument("--use_compile", default=0, type=int, choices=[0, 1], help="是否使用torch.compile加速 (0=否, 1=是)")
|
||||
args = parser.parse_args()
|
||||
|
||||
# ========== 1. 初始化环境和随机种子 ==========
|
||||
# ========== 1.初始化环境和随机种子 ==========
|
||||
# 初始化分布式训练环境, 获取本地 rank
|
||||
local_rank = init_distributed_mode()
|
||||
# 如果处于分布式训练环境, 根据 local_rank 设置设备
|
||||
if dist.is_initialized(): args.device = f"cuda:{local_rank}"
|
||||
# 设置随机种子, 分布式环境下根据 rank 添加偏移以确保不同进程种子不同
|
||||
setup_seed(42 + (dist.get_rank() if dist.is_initialized() else 0))
|
||||
|
||||
# ========== 2. 配置目录、模型参数、检查ckp ==========
|
||||
|
||||
# ========== 2.配置目录、模型参数、检查 ckp ==========
|
||||
# 创建模型保存目录
|
||||
os.makedirs(args.save_dir, exist_ok=True)
|
||||
# 创建模型配置对象
|
||||
lm_config = MiniMindConfig(hidden_size=args.hidden_size, num_hidden_layers=args.num_hidden_layers, use_moe=bool(args.use_moe))
|
||||
# 如果启用断点续训, 尝试加载检查点数据
|
||||
ckp_data = lm_checkpoint(lm_config, weight=args.save_weight, save_dir='../checkpoints') if args.from_resume==1 else None
|
||||
|
||||
# ========== 3. 设置混合精度 ==========
|
||||
|
||||
# ========== 3.设置混合精度 ==========
|
||||
# 判断设备类型
|
||||
device_type = "cuda" if "cuda" in args.device else "cpu"
|
||||
# 设置数据类型
|
||||
dtype = torch.bfloat16 if args.dtype == "bfloat16" else torch.float16
|
||||
autocast_ctx = nullcontext() if device_type == "cpu" else torch.cuda.amp.autocast(dtype=dtype)
|
||||
|
||||
# ========== 4. 配wandb ==========
|
||||
# https://docs.pytorch.ac.cn/docs/stable/amp
|
||||
# 新版废弃: autocast_ctx = nullcontext() if device_type == "cpu" else torch.cuda.amp.autocast(dtype=dtype)
|
||||
autocast_ctx = nullcontext() if device_type == "cpu" else torch.amp.autocast(device_type=device_type, dtype=dtype)
|
||||
|
||||
# ========== 4. 配置 wandb ==========
|
||||
wandb = None
|
||||
if args.use_wandb and is_main_process():
|
||||
# 使用 swanlab 作为 wandb 的替代
|
||||
import swanlab as wandb
|
||||
# 尝试从检查点恢复 wandb ID
|
||||
wandb_id = ckp_data.get('wandb_id') if ckp_data else None
|
||||
resume = 'must' if wandb_id else None
|
||||
# 构建运行名称
|
||||
wandb_run_name = f"MiniMind-PPO-Epoch-{args.epochs}-BS-{args.batch_size}-LR-{args.learning_rate}"
|
||||
# 初始化 wandb
|
||||
wandb.init(project=args.wandb_project, name=wandb_run_name, id=wandb_id, resume=resume)
|
||||
|
||||
|
||||
# ========== 5. 初始化模型和数据 ==========
|
||||
# 根据是否训练推理模型选择基础权重名称
|
||||
base_weight = "reason" if args.reasoning == 1 else "full_sft"
|
||||
# Actor模型
|
||||
|
||||
# Actor 模型 (策略网络, 需要训练)
|
||||
actor_model, tokenizer = init_model(lm_config, base_weight, device=args.device)
|
||||
if args.use_compile == 1:
|
||||
# 使用 torch.compile 加速模型 (需要 PyTorch 2.0+)
|
||||
actor_model = torch.compile(actor_model)
|
||||
Logger('torch.compile enabled')
|
||||
# Old Actor模型
|
||||
|
||||
# Old Actor 模型 (旧策略, 用于 PPO 的重要性采样)
|
||||
old_actor_model, _ = init_model(lm_config, base_weight, device=args.device)
|
||||
# 设置为评估模式, 不计算梯度
|
||||
old_actor_model = old_actor_model.eval().requires_grad_(False)
|
||||
# Reference模型
|
||||
|
||||
# Reference 模型 (参考模型, 用于 KL 散度惩罚)
|
||||
ref_model, _ = init_model(lm_config, base_weight, device=args.device)
|
||||
ref_model = ref_model.eval().requires_grad_(False)
|
||||
# Critic模型
|
||||
|
||||
# Critic 模型 (价值网络, 估计状态价值)
|
||||
moe_suffix = '_moe' if lm_config.use_moe else ''
|
||||
ckp = f'{args.save_dir}/{base_weight}_{lm_config.hidden_size}{moe_suffix}.pth'
|
||||
# 加载基础模型权重初始化 Critic
|
||||
state_dict = torch.load(ckp, map_location=args.device)
|
||||
critic_model = CriticModel(lm_config)
|
||||
critic_model.load_state_dict(state_dict, strict=False)
|
||||
critic_model = critic_model.to(args.device)
|
||||
# Reward模型
|
||||
reward_model = AutoModel.from_pretrained(
|
||||
args.reward_model_path, torch_dtype=torch.float16, trust_remote_code=True
|
||||
)
|
||||
|
||||
# Reward 模型 (奖励模型, 提供奖励信号)
|
||||
reward_model = AutoModel.from_pretrained(args.reward_model_path, torch_dtype=torch.float16, trust_remote_code=True)
|
||||
reward_model = reward_model.to(args.device).eval().requires_grad_(False)
|
||||
reward_tokenizer = AutoTokenizer.from_pretrained(args.reward_model_path, trust_remote_code=True)
|
||||
# 数据和优化器
|
||||
|
||||
# 数据集和优化器配置
|
||||
train_ds = RLAIFDataset(args.data_path, tokenizer, max_length=(args.max_seq_len + args.max_gen_len))
|
||||
train_sampler = DistributedSampler(train_ds) if dist.is_initialized() else None
|
||||
actor_optimizer = optim.AdamW(actor_model.parameters(), lr=args.learning_rate)
|
||||
critic_optimizer = optim.AdamW(critic_model.parameters(), lr=args.critic_learning_rate)
|
||||
# 计算训练迭代次数
|
||||
loader_for_count = DataLoader(train_ds, batch_size=args.batch_size, sampler=train_sampler)
|
||||
iters = len(loader_for_count)
|
||||
# 计算总优化步数 (考虑梯度累积)
|
||||
total_optimizer_steps = (iters // args.accumulation_steps) * args.epochs
|
||||
# 学习率调度器 (余弦退火)
|
||||
actor_scheduler = CosineAnnealingLR(actor_optimizer, T_max=total_optimizer_steps, eta_min=args.learning_rate / 10)
|
||||
critic_scheduler = CosineAnnealingLR(critic_optimizer, T_max=total_optimizer_steps, eta_min=args.critic_learning_rate / 10)
|
||||
|
||||
# ========== 6. 从ckp恢复状态 ==========
|
||||
|
||||
# ==========6. 从ckp 恢复状态 ==========
|
||||
start_epoch, start_step = 0, 0
|
||||
if ckp_data:
|
||||
# 恢复模型权重
|
||||
actor_model.load_state_dict(ckp_data['model'])
|
||||
critic_model.load_state_dict(ckp_data['critic_model'])
|
||||
# 恢复优化器状态
|
||||
actor_optimizer.load_state_dict(ckp_data['optimizer'])
|
||||
critic_optimizer.load_state_dict(ckp_data['critic_optimizer'])
|
||||
# 恢复学习率调度器状态
|
||||
actor_scheduler.load_state_dict(ckp_data['scheduler'])
|
||||
critic_scheduler.load_state_dict(ckp_data['critic_scheduler'])
|
||||
# 恢复训练进度
|
||||
start_epoch = ckp_data['epoch']
|
||||
start_step = ckp_data.get('step', 0)
|
||||
|
||||
# ========== 7. DDP包模型 ==========
|
||||
|
||||
# ========== 7.DDP 包装模型 ==========
|
||||
if dist.is_initialized():
|
||||
# 设置 DDP 忽略的参数 (位置编码相关, 不需要梯度同步)
|
||||
actor_model._ddp_params_and_buffers_to_ignore = {"freqs_cos", "freqs_sin"}
|
||||
critic_model._ddp_params_and_buffers_to_ignore = {"freqs_cos", "freqs_sin"}
|
||||
# 包装为 DDP 模型
|
||||
actor_model = DistributedDataParallel(actor_model, device_ids=[local_rank])
|
||||
critic_model = DistributedDataParallel(critic_model, device_ids=[local_rank])
|
||||
# 确保旧策略模型在正确设备上
|
||||
old_actor_model.to(args.device)
|
||||
|
||||
# ========== 8. 开始训练 ==========
|
||||
|
||||
# ========== 8.开始训练 ==========
|
||||
for epoch in range(start_epoch, args.epochs):
|
||||
# 设置分布式采样器的 epoch (确保每个 epoch 数据打乱方式不同)
|
||||
train_sampler and train_sampler.set_epoch(epoch)
|
||||
# 每个 epoch 使用不同的随机种子
|
||||
setup_seed(42 + epoch); indices = torch.randperm(len(train_ds)).tolist()
|
||||
# 如果是续训的第一个 epoch, 需要跳过已经训练过的 step
|
||||
skip = start_step if (epoch == start_epoch and start_step > 0) else 0
|
||||
# 创建支持跳过指定批次数量的采样器
|
||||
batch_sampler = SkipBatchSampler(train_sampler or indices, args.batch_size, skip)
|
||||
# 创建数据加载器
|
||||
loader = DataLoader(train_ds, batch_sampler=batch_sampler, num_workers=args.num_workers, pin_memory=True)
|
||||
if skip > 0:
|
||||
Logger(f'Epoch [{epoch + 1}/{args.epochs}]: 跳过前{start_step}个step,从step {start_step + 1}开始')
|
||||
ppo_train_epoch(epoch, loader, len(loader) + skip, old_actor_model, ref_model,
|
||||
actor_scheduler, critic_scheduler, reward_model, reward_tokenizer, start_step, wandb)
|
||||
if skip > 0:
|
||||
Logger(f'Epoch [{epoch + 1}/{args.epochs}]: 跳过前{start_step}个step, 从step {start_step + 1}开始')
|
||||
ppo_train_epoch(
|
||||
epoch,
|
||||
loader,
|
||||
len(loader) + skip,
|
||||
old_actor_model,
|
||||
ref_model,
|
||||
actor_scheduler,
|
||||
critic_scheduler,
|
||||
reward_model,
|
||||
reward_tokenizer,
|
||||
start_step,
|
||||
wandb)
|
||||
else:
|
||||
ppo_train_epoch(epoch, loader, len(loader), old_actor_model, ref_model,
|
||||
actor_scheduler, critic_scheduler, reward_model, reward_tokenizer, 0, wandb)
|
||||
|
||||
# ========== 9. 清理分布进程 ==========
|
||||
ppo_train_epoch(
|
||||
epoch,
|
||||
loader,
|
||||
len(loader),
|
||||
old_actor_model,
|
||||
ref_model,
|
||||
actor_scheduler,
|
||||
critic_scheduler,
|
||||
reward_model,
|
||||
reward_tokenizer,
|
||||
0,
|
||||
wandb)
|
||||
|
||||
# ========== 9.清理分布式进程 ==========
|
||||
if dist.is_initialized(): dist.destroy_process_group()
|
||||
@ -1,72 +1,131 @@
|
||||
import os
|
||||
import sys
|
||||
|
||||
# 设置包名
|
||||
__package__ = "trainer"
|
||||
# 将项目根目录添加到系统路径, 确保能够导入同级目录的模块
|
||||
sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), '..')))
|
||||
|
||||
import argparse
|
||||
import time
|
||||
import warnings
|
||||
import torch
|
||||
# 分布式训练支持
|
||||
import torch.distributed as dist
|
||||
|
||||
# 上下文管理器
|
||||
# https://docs.python.org/zh-cn/3.14/library/contextlib.html#contextlib.nullcontext
|
||||
from contextlib import nullcontext
|
||||
from torch import optim, nn
|
||||
# 分布式数据并行 DDP
|
||||
from torch.nn.parallel import DistributedDataParallel
|
||||
# 数据加载器和分布式采样器
|
||||
from torch.utils.data import DataLoader, DistributedSampler
|
||||
from model.model_minimind import MiniMindConfig
|
||||
from dataset.lm_dataset import PretrainDataset
|
||||
# 预训练数据集
|
||||
from dataset.lm_dataset import PretrainDataset
|
||||
# 导入训练工具函数:学习率计算、日志记录、检查点保存、分布式初始化等
|
||||
from trainer.trainer_utils import get_lr, Logger, is_main_process, lm_checkpoint, init_distributed_mode, setup_seed, init_model, SkipBatchSampler
|
||||
|
||||
# 忽略所有警告信息,减少输出干扰
|
||||
import warnings
|
||||
warnings.filterwarnings('ignore')
|
||||
|
||||
|
||||
def train_epoch(epoch, loader, iters, start_step=0, wandb=None):
|
||||
"""
|
||||
执行单个训练 epoch
|
||||
|
||||
Args:
|
||||
- epoch: 当前 epoch 编号
|
||||
- loader: 数据加载器
|
||||
- iters: 当前 epoch 的总迭代次数
|
||||
- start_step: 起始 step 编号 (用于断点续训)
|
||||
- wandb: 实验日志工具实例
|
||||
"""
|
||||
# 记录当前 epoch 开始时间
|
||||
start_time = time.time()
|
||||
|
||||
# 遍历数据加载器, 获取批次数据
|
||||
for step, (input_ids, labels) in enumerate(loader, start=start_step + 1):
|
||||
# 将 数据 & 标签 移动到指定设备
|
||||
input_ids = input_ids.to(args.device)
|
||||
labels = labels.to(args.device)
|
||||
|
||||
# 根据当前 step 计算当前学习率
|
||||
lr = get_lr(epoch * iters + step, args.epochs * iters, args.learning_rate)
|
||||
# 更新优化器中所有参数组的学习率
|
||||
for param_group in optimizer.param_groups:
|
||||
param_group['lr'] = lr
|
||||
|
||||
# 使用混合精度上下文进行前向传播
|
||||
# - 在 __name__ == "__main__" 中定义
|
||||
# - autocast_ctx = nullcontext() if device_type == "cpu" else torch.amp.autocast(dtype=dtype)
|
||||
with autocast_ctx:
|
||||
# 模型前向传播, 其中包含损失
|
||||
res = model(input_ids, labels=labels)
|
||||
# 合并主损失和辅助损失 (如 MoE 的负载均衡损失)
|
||||
loss = res.loss + res.aux_loss
|
||||
# 根据梯度累积步数进行缩放
|
||||
loss = loss / args.accumulation_steps
|
||||
|
||||
# 反向传播计算梯度
|
||||
# - 在 __name__ == "__main__" 中定义
|
||||
# - scaler = torch.amp.GradScaler("cuda", enabled=(args.dtype == 'float16'))
|
||||
scaler.scale(loss).backward()
|
||||
|
||||
# 当达到梯度累积步数时,执行优化器更新
|
||||
if (step + 1) % args.accumulation_steps == 0:
|
||||
# 取消梯度缩放, 为优化器步骤做准备
|
||||
scaler.unscale_(optimizer)
|
||||
# 对梯度进行裁剪, 防止梯度爆炸
|
||||
torch.nn.utils.clip_grad_norm_(model.parameters(), args.grad_clip)
|
||||
|
||||
# 执行优化器步骤更新参数
|
||||
scaler.step(optimizer)
|
||||
# 更新缩放器状态
|
||||
scaler.update()
|
||||
|
||||
# 清空梯度, 设置为 None 以节省内存
|
||||
optimizer.zero_grad(set_to_none=True)
|
||||
|
||||
# 达到日志打印间隔时, 记录训练状态
|
||||
if step % args.log_interval == 0 or step == iters - 1:
|
||||
# 计算已消耗时间
|
||||
spend_time = time.time() - start_time
|
||||
# 获取当前损失值
|
||||
current_loss = loss.item() * args.accumulation_steps
|
||||
# 获取辅助损失值
|
||||
current_aux_loss = res.aux_loss.item() if res.aux_loss is not None else 0.0
|
||||
# 计算 logits 损失 (主损失)
|
||||
current_logits_loss = current_loss - current_aux_loss
|
||||
# 获取当前学习率
|
||||
current_lr = optimizer.param_groups[-1]['lr']
|
||||
# 估算剩余时间 (分钟)
|
||||
eta_min = spend_time / (step + 1) * iters // 60 - spend_time // 60
|
||||
# 打印训练日志
|
||||
Logger(f'Epoch:[{epoch + 1}/{args.epochs}]({step}/{iters}), loss: {current_loss:.4f}, logits_loss: {current_logits_loss:.4f}, aux_loss: {current_aux_loss:.4f}, lr: {current_lr:.8f}, epoch_time: {eta_min:.1f}min')
|
||||
# 记录到 wandb (如果启用)
|
||||
if wandb: wandb.log({"loss": current_loss, "logits_loss": current_logits_loss, "aux_loss": current_aux_loss, "learning_rate": current_lr, "epoch_time": eta_min})
|
||||
|
||||
# 达到保存间隔且为主进程时, 保存模型检查点
|
||||
if (step % args.save_interval == 0 or step == iters - 1) and is_main_process():
|
||||
# 设置为评估模式
|
||||
model.eval()
|
||||
# 生成模型文件名后缀, 区分 MoE 和普通模型
|
||||
moe_suffix = '_moe' if lm_config.use_moe else ''
|
||||
ckp = f'{args.save_dir}/{args.save_weight}_{lm_config.hidden_size}{moe_suffix}.pth'
|
||||
# 获取原始模型 (移除 DDP 包装)
|
||||
raw_model = model.module if isinstance(model, DistributedDataParallel) else model
|
||||
raw_model = getattr(raw_model, '_orig_mod', raw_model)
|
||||
# 获取模型状态字典
|
||||
state_dict = raw_model.state_dict()
|
||||
# 保存模型权重 (转换为半精度以节省空间)
|
||||
torch.save({k: v.half().cpu() for k, v in state_dict.items()}, ckp)
|
||||
# 保存完整训练检查点 (包含配置、优化器状态等)
|
||||
lm_checkpoint(lm_config, weight=args.save_weight, model=model, optimizer=optimizer, scaler=scaler, epoch=epoch, step=step, wandb=wandb, save_dir='../checkpoints')
|
||||
# 恢复训练模式
|
||||
model.train()
|
||||
# 释放显存
|
||||
del state_dict
|
||||
|
||||
# 清理当前批次的数据, 释放显存
|
||||
del input_ids, labels, res, loss
|
||||
|
||||
|
||||
@ -74,7 +133,7 @@ if __name__ == "__main__":
|
||||
parser = argparse.ArgumentParser(description="MiniMind Pretraining")
|
||||
parser.add_argument("--save_dir", type=str, default="../out", help="模型保存目录")
|
||||
parser.add_argument('--save_weight', default='pretrain', type=str, help="保存权重的前缀名")
|
||||
parser.add_argument("--epochs", type=int, default=1, help="训练轮数(建议1轮zero或2-6轮充分训练)")
|
||||
parser.add_argument("--epochs", type=int, default=1, help="训练轮数 (建议 1 轮 zero 或 2-6 轮充分训练)")
|
||||
parser.add_argument("--batch_size", type=int, default=32, help="batch size")
|
||||
parser.add_argument("--learning_rate", type=float, default=5e-4, help="初始学习率")
|
||||
parser.add_argument("--device", type=str, default="cuda:0" if torch.cuda.is_available() else "cpu", help="训练设备")
|
||||
@ -86,52 +145,76 @@ if __name__ == "__main__":
|
||||
parser.add_argument("--save_interval", type=int, default=1000, help="模型保存间隔")
|
||||
parser.add_argument('--hidden_size', default=512, type=int, help="隐藏层维度")
|
||||
parser.add_argument('--num_hidden_layers', default=8, type=int, help="隐藏层数量")
|
||||
parser.add_argument('--max_seq_len', default=340, type=int, help="训练的最大截断长度(中文1token≈1.5~1.7字符)")
|
||||
parser.add_argument('--use_moe', default=0, type=int, choices=[0, 1], help="是否使用MoE架构(0=否,1=是)")
|
||||
parser.add_argument('--max_seq_len', default=340, type=int, help="训练的最大截断长度 (中文 1token≈1.5~1.7 字符)")
|
||||
parser.add_argument('--use_moe', default=0, type=int, choices=[0, 1], help="是否使用 MoE 架构 (0=否, 1=是)")
|
||||
parser.add_argument("--data_path", type=str, default="../dataset/pretrain_hq.jsonl", help="预训练数据路径")
|
||||
parser.add_argument('--from_weight', default='none', type=str, help="基于哪个权重训练,为none则从头开始")
|
||||
parser.add_argument('--from_resume', default=0, type=int, choices=[0, 1], help="是否自动检测&续训(0=否,1=是)")
|
||||
parser.add_argument("--use_wandb", action="store_true", help="是否使用wandb")
|
||||
parser.add_argument("--wandb_project", type=str, default="MiniMind-Pretrain", help="wandb项目名")
|
||||
parser.add_argument("--use_compile", default=0, type=int, choices=[0, 1], help="是否使用torch.compile加速(0=否,1=是)")
|
||||
parser.add_argument('--from_weight', default='none', type=str, help="基于哪个权重训练, 为 none 则从头开始")
|
||||
parser.add_argument('--from_resume', default=0, type=int, choices=[0, 1], help="是否自动检测 & 续训 (0=否, 1=是)")
|
||||
parser.add_argument("--use_wandb", action="store_true", help="是否使用 wandb")
|
||||
parser.add_argument("--wandb_project", type=str, default="MiniMind-Pretrain", help="wandb 项目名")
|
||||
parser.add_argument("--use_compile", default=0, type=int, choices=[0, 1], help="是否使用 torch.compile 加速 (0=否, 1=是)")
|
||||
args = parser.parse_args()
|
||||
|
||||
# ========== 1. 初始化环境和随机种子 ==========
|
||||
# ========== 1.初始化环境和随机种子 ===========
|
||||
# 初始化分布式训练模式, 返回本地 GPU 编号
|
||||
local_rank = init_distributed_mode()
|
||||
# 如果已初始化分布式环境, 更新设备为对应 GPU
|
||||
if dist.is_initialized(): args.device = f"cuda:{local_rank}"
|
||||
# 设置随机种子, 确保实验可复现 (不同进程使用不同种子避免同步)
|
||||
setup_seed(42 + (dist.get_rank() if dist.is_initialized() else 0))
|
||||
|
||||
# ========== 2. 配置目录、模型参数、检查ckp ==========
|
||||
# ========== 2.配置目录、模型参数、检查 ckp ===========
|
||||
# 创建保存目录 (如果不存在)
|
||||
os.makedirs(args.save_dir, exist_ok=True)
|
||||
# 根据命令行参数创建 MiniMind 模型配置
|
||||
lm_config = MiniMindConfig(hidden_size=args.hidden_size, num_hidden_layers=args.num_hidden_layers, use_moe=bool(args.use_moe))
|
||||
# 如果启用断点续训, 尝试加载已有检查点信息
|
||||
ckp_data = lm_checkpoint(lm_config, weight=args.save_weight, save_dir='../checkpoints') if args.from_resume==1 else None
|
||||
|
||||
# ========== 3. 设置混合精度 ==========
|
||||
# ========== 3.设置混合精度 ==========
|
||||
device_type = "cuda" if "cuda" in args.device else "cpu"
|
||||
# 根据参数选择精度类型 (bfloat16 或 float16)
|
||||
dtype = torch.bfloat16 if args.dtype == "bfloat16" else torch.float16
|
||||
autocast_ctx = nullcontext() if device_type == "cpu" else torch.cuda.amp.autocast(dtype=dtype)
|
||||
# https://docs.pytorch.ac.cn/docs/stable/amp
|
||||
# 新版废弃: autocast_ctx = nullcontext() if device_type == "cpu" else torch.cuda.amp.autocast(dtype=dtype)
|
||||
autocast_ctx = nullcontext() if device_type == "cpu" else torch.amp.autocast(device_type=device_type, dtype=dtype)
|
||||
|
||||
# ========== 4. 配wandb ==========
|
||||
# ========== 4.配置 wandb ==========
|
||||
wandb = None
|
||||
# 仅主进程启用 wandb, 避免日志重复
|
||||
if args.use_wandb and is_main_process():
|
||||
# 导入日志工具
|
||||
import swanlab as wandb
|
||||
# 从检查点获取 wandb run ID (用于恢复实验)
|
||||
wandb_id = ckp_data.get('wandb_id') if ckp_data else None
|
||||
# 设置恢复模式:如果有 ID 则恢复之前实验, 否则创建新实验
|
||||
resume = 'must' if wandb_id else None
|
||||
# 生成实验名称, 包含关键超参数
|
||||
wandb_run_name = f"MiniMind-Pretrain-Epoch-{args.epochs}-BatchSize-{args.batch_size}-LearningRate-{args.learning_rate}"
|
||||
# 初始化 wandb 实验
|
||||
wandb.init(project=args.wandb_project, name=wandb_run_name, id=wandb_id, resume=resume)
|
||||
|
||||
# ========== 5. 定义模型、数据、优化器 ==========
|
||||
# ========== 5.定义模型、数据、优化器 ==========
|
||||
# 初始化模型和分词器
|
||||
model, tokenizer = init_model(lm_config, args.from_weight, device=args.device)
|
||||
# 如果启用 torch.compile, 对模型进行编译优化 (需要 PyTorch 2.0+ & Linux)
|
||||
if args.use_compile == 1:
|
||||
model = torch.compile(model)
|
||||
Logger('torch.compile enabled')
|
||||
# 创建预训练数据集
|
||||
train_ds = PretrainDataset(args.data_path, tokenizer, max_length=args.max_seq_len)
|
||||
# 如果使用分布式训练, 创建分布式采样器
|
||||
train_sampler = DistributedSampler(train_ds) if dist.is_initialized() else None
|
||||
scaler = torch.cuda.amp.GradScaler(enabled=(args.dtype == 'float16'))
|
||||
# 创建梯度缩放器 (用于混合精度训练)
|
||||
# https://docs.pytorch.ac.cn/docs/stable/amp.html#gradient-scaling
|
||||
# 新版废弃: scaler = torch.cuda.amp.GradScaler(enabled=(args.dtype == 'float16'))
|
||||
scaler = torch.amp.GradScaler("cuda", enabled=(args.dtype == 'float16'))
|
||||
# 创建 AdamW 优化器
|
||||
optimizer = optim.AdamW(model.parameters(), lr=args.learning_rate)
|
||||
|
||||
# ========== 6. 从ckp恢复状态 ==========
|
||||
# ========== 6.从 ckp 中恢复状态 ==========
|
||||
start_epoch, start_step = 0, 0
|
||||
# 如果有检查点数据, 恢复训练状态
|
||||
if ckp_data:
|
||||
model.load_state_dict(ckp_data['model'])
|
||||
optimizer.load_state_dict(ckp_data['optimizer'])
|
||||
@ -139,23 +222,33 @@ if __name__ == "__main__":
|
||||
start_epoch = ckp_data['epoch']
|
||||
start_step = ckp_data.get('step', 0)
|
||||
|
||||
# ========== 7. DDP包模型 ==========
|
||||
# ========== 7.DDP 包装模型 ==========
|
||||
# 如果使用分布式训练, 使用包装模型为 DDP 模式
|
||||
if dist.is_initialized():
|
||||
# 忽略频率向量 (避免 DDP 重复广播)
|
||||
model._ddp_params_and_buffers_to_ignore = {"freqs_cos", "freqs_sin"}
|
||||
model = DistributedDataParallel(model, device_ids=[local_rank])
|
||||
|
||||
# ========== 8. 开始训练 ==========
|
||||
# ========== 8.开始训练 ==========
|
||||
# 遍历所有 epoch
|
||||
for epoch in range(start_epoch, args.epochs):
|
||||
# 如果使用分布式采样器, 设置 epoch 以确保数据打乱
|
||||
train_sampler and train_sampler.set_epoch(epoch)
|
||||
# 设置随机种子并生成打乱的索引顺序
|
||||
setup_seed(42 + epoch); indices = torch.randperm(len(train_ds)).tolist()
|
||||
# 计算需要跳过的 step 数 (断点续训时)
|
||||
skip = start_step if (epoch == start_epoch and start_step > 0) else 0
|
||||
# 创建批次采样器, 支持跳过指定数量的批次
|
||||
batch_sampler = SkipBatchSampler(train_sampler or indices, args.batch_size, skip)
|
||||
# 创建数据加载器
|
||||
loader = DataLoader(train_ds, batch_sampler=batch_sampler, num_workers=args.num_workers, pin_memory=True)
|
||||
# 如果有跳过, 跳过前 start_step 个 step
|
||||
if skip > 0:
|
||||
Logger(f'Epoch [{epoch + 1}/{args.epochs}]: 跳过前{start_step}个step,从step {start_step + 1}开始')
|
||||
Logger(f'Epoch [{epoch + 1}/{args.epochs}]: 跳过前 {start_step} 个 step, 从 step {start_step + 1} 开始')
|
||||
train_epoch(epoch, loader, len(loader) + skip, start_step, wandb)
|
||||
else:
|
||||
train_epoch(epoch, loader, len(loader), 0, wandb)
|
||||
|
||||
# ========== 9. 清理分布进程 ==========
|
||||
# ========== 9.清理分布进程 ==========
|
||||
# 销毁分布式进程组, 释放资源
|
||||
if dist.is_initialized(): dist.destroy_process_group()
|
||||
@ -1,12 +1,14 @@
|
||||
# 导入标准库
|
||||
import os
|
||||
import sys
|
||||
|
||||
# 设置包名, 确保能够正确导入同级目录的模块
|
||||
__package__ = "trainer"
|
||||
# 将项目根目录添加到系统路径
|
||||
sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), '..')))
|
||||
|
||||
import argparse
|
||||
import time
|
||||
import warnings
|
||||
import torch
|
||||
import torch.distributed as dist
|
||||
from contextlib import nullcontext
|
||||
@ -17,74 +19,138 @@ from model.model_minimind import MiniMindConfig
|
||||
from dataset.lm_dataset import SFTDataset
|
||||
from trainer.trainer_utils import get_lr, Logger, is_main_process, lm_checkpoint, init_distributed_mode, setup_seed, init_model, SkipBatchSampler
|
||||
|
||||
# 忽略警告信息, 减少控制台输出干扰
|
||||
import warnings
|
||||
warnings.filterwarnings('ignore')
|
||||
|
||||
|
||||
def train_epoch(epoch, loader, iters, tokenizer, lm_config, start_step=0, wandb=None):
|
||||
start_of_think_ids = tokenizer('<think>').input_ids
|
||||
"""
|
||||
单个 epoch 的训练函数
|
||||
|
||||
执行模型的推理蒸馏训练, 包括前向传播、损失计算、梯度累积和反向传播
|
||||
使用特殊 token (<think>, </think>, <answer>, </answer>) 对思考过程和答案部分进行加权
|
||||
|
||||
Args:
|
||||
- epoch: 当前训练轮次
|
||||
- loader: 数据加载器, 提供 (input_ids, labels) 批次
|
||||
- iters: 当前 epoch 的总迭代步数
|
||||
- tokenizer: 分词器对象, 用于获取特殊 token 的 ID
|
||||
- lm_config: 模型配置对象, 包含 use_moe 等属性
|
||||
- start_step: 起始步数, 用于断点续训时跳过已训练的步数
|
||||
- wandb: Weights & Biases 日志对象, 可选
|
||||
"""
|
||||
# 获取特殊 token 的 ID, 用于在损失计算中对思考过程加权
|
||||
# <think> 和 </think> 标记思考过程的开始和结束
|
||||
start_of_think_ids = tokenizer(' <think>').input_ids
|
||||
end_of_think_ids = tokenizer('</think>').input_ids
|
||||
# <answer> 和 </answer> 标记答案的开始和结束
|
||||
start_of_answer_ids = tokenizer('<answer>').input_ids
|
||||
end_of_answer_ids = tokenizer('</answer>').input_ids
|
||||
# 使用不 reduction 的交叉熵损失, 以便后续应用自定义 loss mask
|
||||
loss_fct = nn.CrossEntropyLoss(reduction='none')
|
||||
# 记录 epoch 开始时间, 用于计算训练速度和 ETA
|
||||
start_time = time.time()
|
||||
|
||||
|
||||
# 遍历数据加载器, 从 start_step + 1 开始计数
|
||||
for step, (input_ids, labels) in enumerate(loader, start=start_step + 1):
|
||||
# 将输入数据移动到训练设备
|
||||
input_ids = input_ids.to(args.device)
|
||||
labels = labels.to(args.device)
|
||||
# 计算当前步的学习率, 使用余弦退火策略
|
||||
lr = get_lr(epoch * iters + step, args.epochs * iters, args.learning_rate)
|
||||
# 更新优化器的学习率
|
||||
for param_group in optimizer.param_groups:
|
||||
param_group['lr'] = lr
|
||||
|
||||
# 使用自动混合精度上下文进行前向传播
|
||||
with autocast_ctx:
|
||||
# 模型前向传播, 获取输出结果
|
||||
res = model(input_ids)
|
||||
# 对 logits 和 labels 进行移位, 以进行因果语言建模
|
||||
# shift_logits: 预测下一个 token 的 logits
|
||||
# shift_labels: 实际的下一个 token 的标签
|
||||
shift_logits = res.logits[..., :-1, :].contiguous()
|
||||
shift_labels = labels[..., 1:].contiguous()
|
||||
# 计算每个位置的交叉熵损失
|
||||
loss = loss_fct(shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1)).view(shift_labels.size())
|
||||
|
||||
# 构建损失掩码, 标记有效标签位置 (非 -100 的 padding)
|
||||
loss_mask = (shift_labels != -100).float()
|
||||
sp_ids = torch.isin(shift_labels.view(-1),
|
||||
torch.tensor(start_of_think_ids + end_of_think_ids
|
||||
+ start_of_answer_ids + end_of_answer_ids
|
||||
).to(args.device))
|
||||
# 找出特殊 token 的位置 (思考过程和答案的标记)
|
||||
sp_ids = torch.isin(
|
||||
shift_labels.view(-1),
|
||||
torch.tensor(start_of_think_ids + end_of_think_ids + start_of_answer_ids + end_of_answer_ids).to(args.device)
|
||||
)
|
||||
# 对 loss mask 进行加权: 特殊 token 位置权重设为 10, 增强对推理过程的学习
|
||||
loss_mask_flat = loss_mask.view(-1)
|
||||
loss_mask_sum = loss_mask_flat.sum()
|
||||
loss_mask_flat[sp_ids] = 10
|
||||
loss_mask = loss_mask_flat.view(shift_labels.size())
|
||||
# 计算加权后的 logits 损失
|
||||
logits_loss = (loss * loss_mask).sum() / loss_mask_sum
|
||||
# 总损失 = 加权 logits 损失 + MoE 辅助损失 (如使用 MoE)
|
||||
loss = logits_loss + res.aux_loss
|
||||
# 梯度累积: 将损失除以累积步数, 模拟大批量训练
|
||||
loss = loss / args.accumulation_steps
|
||||
|
||||
# 使用梯度缩放器进行反向传播, 防止 FP16/BF16 梯度下溢
|
||||
scaler.scale(loss).backward()

        # Once the accumulation step count is reached, run the optimizer update
        if (step + 1) % args.accumulation_steps == 0:
            # Unscale the gradients so they can be clipped
            scaler.unscale_(optimizer)
            # Clip gradients to prevent explosion and keep training stable
            torch.nn.utils.clip_grad_norm_(model.parameters(), args.grad_clip)
            # Optimizer update
            scaler.step(optimizer)
            # Update the scaler's scaling factor
            scaler.update()
            # Clear gradients; set_to_none=True saves memory
            optimizer.zero_grad(set_to_none=True)

        # Print training status every log_interval steps, and on the final step
        if step % args.log_interval == 0 or step == iters - 1:
            # Elapsed time and the current loss values
            spend_time = time.time() - start_time
            # Undo the earlier division by accumulation_steps to report the raw loss
            current_loss = loss.item() * args.accumulation_steps
            # MoE auxiliary loss (if present)
            current_aux_loss = res.aux_loss.item() if res.aux_loss is not None else 0.0
            current_logits_loss = logits_loss.item()
            current_lr = optimizer.param_groups[-1]['lr']
            # Estimated remaining time for this epoch (minutes)
            eta_min = spend_time / (step + 1) * iters // 60 - spend_time // 60
            # Print the training log
            Logger(f'Epoch:[{epoch + 1}/{args.epochs}]({step}/{iters}), loss: {current_loss:.4f}, logits_loss: {current_logits_loss:.4f}, aux_loss: {current_aux_loss:.4f}, lr: {current_lr:.8f}, epoch_time: {eta_min:.1f}min')
            # Record the metrics to wandb if enabled
            if wandb: wandb.log({"loss": current_loss, "logits_loss": current_logits_loss, "aux_loss": current_aux_loss, "learning_rate": current_lr, "epoch_time": eta_min})

        # Save a model checkpoint every save_interval steps (main process only)
        if (step % args.save_interval == 0 or step == iters - 1) and is_main_process():
            # Switch to eval mode before saving
            model.eval()
            # The file-name suffix depends on whether MoE is used
            moe_suffix = '_moe' if lm_config.use_moe else ''
            # Build the model save path
            ckp = f'{args.save_dir}/{args.save_weight}_{lm_config.hidden_size}{moe_suffix}.pth'
            # Unwrap the DDP wrapper if present
            raw_model = model.module if isinstance(model, DistributedDataParallel) else model
            # Handle the torch.compile wrapper
            raw_model = getattr(raw_model, '_orig_mod', raw_model)
            # Grab the model state dict
            state_dict = raw_model.state_dict()
            # Convert the weights to half precision (FP16) on the CPU to shrink the file
            torch.save({k: v.half().cpu() for k, v in state_dict.items()}, ckp)
            # Also save a full checkpoint (optimizer state etc., for resuming)
            lm_checkpoint(lm_config, weight=args.save_weight, model=model, optimizer=optimizer, scaler=scaler, epoch=epoch, step=step, wandb=wandb, save_dir='../checkpoints')
            # Back to training mode
            model.train()
            # Free the state-dict memory
            del state_dict

        # Drop the batch tensors to free GPU memory
        del input_ids, labels, res, loss

@ -104,76 +170,110 @@ if __name__ == "__main__":
    parser.add_argument("--save_interval", type=int, default=100, help="model save interval")
    parser.add_argument('--hidden_size', default=512, type=int, help="hidden dimension")
    parser.add_argument('--num_hidden_layers', default=8, type=int, help="number of hidden layers")
    parser.add_argument('--max_seq_len', default=720, type=int, help="maximum truncation length for training (Chinese: 1 token ≈ 1.5-1.7 characters)")
    parser.add_argument('--use_moe', default=0, type=int, choices=[0, 1], help="use the MoE architecture (0=no, 1=yes)")
    parser.add_argument("--data_path", type=str, default="../dataset/r1_mix_1024.jsonl", help="reasoning-distillation data path")
    parser.add_argument('--from_weight', default='dpo', type=str, help="weights to start training from, default dpo")
    parser.add_argument('--from_resume', default=0, type=int, choices=[0, 1], help="auto-detect & resume training (0=no, 1=yes)")
    parser.add_argument("--use_wandb", action="store_true", help="use wandb")
    parser.add_argument("--wandb_project", type=str, default="MiniMind-Reasoning", help="wandb project name")
    parser.add_argument("--use_compile", default=0, type=int, choices=[0, 1], help="use torch.compile acceleration (0=no, 1=yes)")
    args = parser.parse_args()

    # ========== 1. Initialize the distributed environment and random seeds ==========
    # Initialize distributed training mode and get the local rank
    local_rank = init_distributed_mode()
    # When distributed training is initialized, pick the device from the local rank
    if dist.is_initialized(): args.device = f"cuda:{local_rank}"
    # Set the random seeds; each rank gets a different seed in the distributed setting
    setup_seed(42 + (dist.get_rank() if dist.is_initialized() else 0))

    # ========== 2. Configure directories and model parameters, check for checkpoints ==========
    # Create the save directory (if it does not exist)
    os.makedirs(args.save_dir, exist_ok=True)
    # Initialize the model config
    lm_config = MiniMindConfig(hidden_size=args.hidden_size, num_hidden_layers=args.num_hidden_layers, use_moe=bool(args.use_moe))
    # When resume is enabled, try to load a checkpoint
    ckp_data = lm_checkpoint(lm_config, weight=args.save_weight, save_dir='../checkpoints') if args.from_resume == 1 else None

    # ========== 3. Set up mixed-precision training ==========
    # Determine the device type
    device_type = "cuda" if "cuda" in args.device else "cpu"
    # Pick the data type from the arguments (BF16 or FP16)
    dtype = torch.bfloat16 if args.dtype == "bfloat16" else torch.float16
    # https://docs.pytorch.ac.cn/docs/stable/amp
    # Deprecated in newer releases: autocast_ctx = nullcontext() if device_type == "cpu" else torch.cuda.amp.autocast(dtype=dtype)
    autocast_ctx = nullcontext() if device_type == "cpu" else torch.amp.autocast(device_type=device_type, dtype=dtype)

    # ========== 4. Configure wandb logging ==========
    wandb = None
    # Only initialize on the main process, and only when wandb is enabled
    if args.use_wandb and is_main_process():
        import swanlab as wandb
        # Restore the wandb run ID when resuming from a checkpoint
        wandb_id = ckp_data.get('wandb_id') if ckp_data else None
        resume = 'must' if wandb_id else None
        # Build the run name
        wandb_run_name = f"MiniMind-Reasoning-Epoch-{args.epochs}-BatchSize-{args.batch_size}-LR-{args.learning_rate}"
        wandb.init(project=args.wandb_project, name=wandb_run_name, id=wandb_id, resume=resume)

    # ========== 5. Define the model, dataset, and optimizer ==========
    # Initialize the model and tokenizer
    model, tokenizer = init_model(lm_config, args.from_weight, device=args.device)
    # Optionally enable torch.compile acceleration (a PyTorch 2.0+ feature)
    if args.use_compile == 1:
        model = torch.compile(model)
        Logger('torch.compile enabled')
    # Build the training dataset
    train_ds = SFTDataset(args.data_path, tokenizer, max_length=args.max_seq_len)
    # Use a distributed sampler when training is distributed
    train_sampler = DistributedSampler(train_ds) if dist.is_initialized() else None
    # Create the gradient scaler, enabled only in float16 mode (prevents gradient underflow in mixed precision)
    # https://docs.pytorch.ac.cn/docs/stable/amp.html#gradient-scaling
    # Deprecated in newer releases: scaler = torch.cuda.amp.GradScaler(enabled=(args.dtype == 'float16'))
    scaler = torch.amp.GradScaler("cuda", enabled=(args.dtype == 'float16'))
    # Create the AdamW optimizer
    optimizer = optim.AdamW(model.parameters(), lr=args.learning_rate)

    # ========== 6. Restore training state from a checkpoint ==========
    start_epoch, start_step = 0, 0
    if ckp_data:
        # Load model weights
        model.load_state_dict(ckp_data['model'])
        # Load optimizer state
        optimizer.load_state_dict(ckp_data['optimizer'])
        # Load gradient-scaler state
        scaler.load_state_dict(ckp_data['scaler'])
        # Restore the epoch and step counters
        start_epoch = ckp_data['epoch']
        start_step = ckp_data.get('step', 0)

    # ========== 7. Wrap the model with DDP ==========
    # Wrap the model with DistributedDataParallel for distributed training
    if dist.is_initialized():
        # Buffers that need no synchronization (RoPE positional-encoding frequencies)
        model._ddp_params_and_buffers_to_ignore = {"freqs_cos", "freqs_sin"}
        model = DistributedDataParallel(model, device_ids=[local_rank])

    # ========== 8. Start training ==========
    for epoch in range(start_epoch, args.epochs):
        # Set the epoch on the distributed sampler so each epoch shuffles differently
        train_sampler and train_sampler.set_epoch(epoch)
        # Seed the current epoch and shuffle the data order
        setup_seed(42 + epoch); indices = torch.randperm(len(train_ds)).tolist()
        # On the first resumed epoch, compute how many steps to skip
        skip = start_step if (epoch == start_epoch and start_step > 0) else 0
        # Sampler that can skip already-trained batches
        batch_sampler = SkipBatchSampler(train_sampler or indices, args.batch_size, skip)
        # Build the data loader
        loader = DataLoader(train_ds, batch_sampler=batch_sampler, num_workers=args.num_workers, pin_memory=True)
        # Run the training epoch
        if skip > 0:
            Logger(f'Epoch [{epoch + 1}/{args.epochs}]: skipping the first {start_step} steps, starting from step {start_step + 1}')
            train_epoch(epoch, loader, len(loader) + skip, tokenizer, lm_config, start_step, wandb)
        else:
            train_epoch(epoch, loader, len(loader), tokenizer, lm_config, 0, wandb)

    # ========== 9. Clean up distributed processes ==========
    # Destroy the process group after training to release resources
    if dist.is_initialized(): dist.destroy_process_group()

@ -1,13 +1,14 @@
import os
import sys

# Set the package name for relative imports
__package__ = "trainer"
# Add the project root to the system path so sibling modules can be imported
sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), '..')))

import argparse
import re
import gc
import warnings
import torch
import torch.distributed as dist
from transformers import AutoTokenizer
@ -21,36 +22,84 @@ from model.model_minimind import MiniMindConfig, MiniMindForCausalLM
from dataset.lm_dataset import RLAIFDataset
from trainer.trainer_utils import Logger, is_main_process, lm_checkpoint, init_distributed_mode, setup_seed, SkipBatchSampler, init_model

# Suppress warnings to keep the output clean (warnings is already imported above)
warnings.filterwarnings('ignore')


class AutoAdaptiveValueTracker:
    """
    SPO adaptive value tracker.

    Estimates the reward baseline with a Beta distribution and supports
    KL-divergence-based adaptive rho updates.
    """
    def __init__(self, rho_mode='kl', rho_const=0.9, D_half=0.06, clip_lower=0.5, clip_upper=0.96):
        # rho mode: 'kl' derives rho from the KL divergence, 'constant' uses a fixed value
        self.rho_mode = rho_mode
        # Constant used in fixed-rho mode
        self.rho_const = rho_const
        # KL half-life parameter, controls how sensitive rho is to KL changes
        self.D_half = D_half
        # Clip bounds for rho, preventing updates that are too fast or too slow
        self.clip_lower = clip_lower
        self.clip_upper = clip_upper
        # Initialize the Beta distribution parameters alpha and beta
        N_init = 1.0 / (1.0 - self.clip_lower)
        self.alpha = 0.5 * N_init
        self.beta = 0.5 * N_init
        # Mean log-probability from the previous iteration, used for the KL estimate
        self.old_mean_logprob = None

    def get_baselines(self, batch_size):
        """
        Compute the baseline values for the current batch.

        Args:
        - batch_size: batch size

        Returns:
        - Tensor: baseline tensor of shape (batch_size,)
        """
        # The mean of the Beta distribution serves as the baseline
        baseline = self.alpha / (self.alpha + self.beta)
        return torch.full((batch_size,), baseline, dtype=torch.float32)

    def compute_rho(self, cur_mean_logprob):
        """
        Compute the adaptive rho value.

        Args:
        - cur_mean_logprob: current mean log-probability

        Returns:
        - float: rho value, between clip_lower and clip_upper
        """
        # Constant mode returns the fixed value
        if self.rho_mode == 'constant':
            return self.rho_const
        # The first iteration falls back to the default
        if self.old_mean_logprob is None:
            return self.rho_const
        # Difference between the current and previous mean log-probability (a KL approximation)
        kl = abs(self.old_mean_logprob - cur_mean_logprob)
        # Exponential-decay formula for rho
        rho = 2 ** (-kl / self.D_half)
        # Clip into the valid range
        return max(min(rho, self.clip_upper), self.clip_lower)
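
    # Illustration of the rho schedule (assumed numbers, not used by training): with
    # D_half = 0.06, rho halves for every 0.06 nats of policy drift between iterations:
    #   kl = 0.00 -> rho = 1.00, clipped to 0.96 (slow baseline forgetting)
    #   kl = 0.06 -> rho = 0.50, exactly at the lower clip (fast forgetting)
    # so the baseline adapts faster whenever the policy moves quickly.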

    def update(self, rewards, cur_logprobs=None, response_masks=None):
        """
        Update the Beta distribution parameters from the rewards.

        Args:
        - rewards: reward tensor of shape (batch_size,)
        - cur_logprobs: log-probabilities under the current policy (optional)
        - response_masks: masks over the response sequences (optional)

        Returns:
        - float: the rho value used for this update
        """
        # When log-probabilities and masks are given, compute the mean log-probability and update rho
        if cur_logprobs is not None and response_masks is not None:
            mean_logprob = ((cur_logprobs * response_masks).sum() / response_masks.sum()).item()
            rho = self.compute_rho(mean_logprob)
@ -58,22 +107,50 @@ class AutoAdaptiveValueTracker:
        else:
            rho = self.rho_const

        # Normalize the rewards into [0, 1] to match the Beta distribution
        scale = 3.0
        normalized_rewards = (rewards + scale) / (2 * scale)
        avg_normalized_reward = normalized_rewards.mean().item()
        # Exponential-moving-average update of the Beta parameters, decayed by rho
        self.alpha = rho * self.alpha + avg_normalized_reward
        self.beta = rho * self.beta + (1 - avg_normalized_reward)
        return rho
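
    # Worked example of the update (default init, assumed rewards): alpha = beta = 1.0
    # gives baseline 0.5; a batch with mean raw reward 1.5 normalizes to (1.5 + 3) / 6 = 0.75,
    # and with rho = 0.9:
    #   alpha = 0.9 * 1.0 + 0.75 = 1.65,  beta = 0.9 * 1.0 + 0.25 = 1.15
    # so the next baseline is 1.65 / 2.80 ~= 0.589, drifting toward the observed rewards.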


def calculate_rewards(prompts, responses, reward_model, reward_tokenizer):
    """
    Combine all reward functions into a total reward.

    Args:
    - prompts: list of input prompts (strings)
    - responses: list of model-generated responses (strings)
    - reward_model: reward model used to score response quality
    - reward_tokenizer: the reward model's tokenizer

    Returns:
    - Tensor: reward per response, shape (batch_size,)
    """
    def reasoning_model_reward(rewards):
        """
        Format reward for the reasoning model.

        Checks whether a response follows the CoT (Chain of Thought) format:
        <think>...</think><answer>...</answer>

        Args:
        - rewards: base reward tensor

        Returns:
        - Tensor: rewards with the format reward added
        """
        # Two accepted format patterns (with or without an extra blank line)
        pattern = r"^<think>\n.*?\n</think>\n<answer>\n.*?\n</answer>$"
        pattern2 = r"^<think>\n.*?\n</think>\n\n<answer>\n.*?\n</answer>$"
        # Check each response against both patterns
        matches_pattern = [re.match(pattern, response, re.S) for response in responses]
        matches_pattern2 = [re.match(pattern2, response, re.S) for response in responses]

        # A fully correct format earns 0.5 points
        format_rewards = []
        for match_pattern, match_pattern2 in zip(matches_pattern, matches_pattern2):
            if match_pattern or match_pattern2:
@ -83,34 +160,44 @@ def calculate_rewards(prompts, responses, reward_model, reward_tokenizer):
        rewards += torch.tensor(format_rewards, device=args.device)

        def mark_num(text):
            """Reward based on counting the special markers in a response."""
            reward = 0
            # Each marker that appears exactly once earns 0.25 points, up to 1.0
            if text.count("<think>") == 1: reward += 0.25
            if text.count("</think>") == 1: reward += 0.25
            if text.count("<answer>") == 1: reward += 0.25
            if text.count("</answer>") == 1: reward += 0.25
            return reward

        # Tally the marker rewards across all responses
        mark_rewards = [mark_num(response) for response in responses]
        rewards += torch.tensor(mark_rewards, device=args.device)
        return rewards
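
    # Example (assumed response, illustration only): "<think>\nsteps\n</think>\n<answer>\n42\n</answer>"
    # matches the first pattern (+0.5) and contains each marker exactly once (+1.0),
    # so its total format bonus is 1.5.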

    # Initialize the rewards to zero
    rewards = torch.zeros(len(responses), device=args.device)
    # When reasoning mode is enabled, apply the format reward first
    if args.reasoning == 1:
        rewards = reasoning_model_reward(rewards)

    # Score response quality with the reward model (no gradients)
    with torch.no_grad():
        reward_model_scores = []
        scale = 3.0

        for i, (prompt, response) in enumerate(zip(prompts, responses)):
            # Extract the conversation history from the prompt
            pattern = r"<\|im_start\|>(system|user|assistant)\s+(.*?)<\|im_end\|>"
            matches = re.findall(pattern, prompt, re.DOTALL)
            messages = [{"role": role, "content": content.strip()} for role, content in matches]

            # Append the current response to the conversation history
            tmp_chat = messages + [{"role": "assistant", "content": response}]
            # Score with the reward model and clip into the valid range
            score = reward_model.get_score(reward_tokenizer, tmp_chat)
            score = max(min(score, scale), -scale)

            # For the reasoning model, additionally score the <answer> content
            if args.reasoning == 1:
                answer_match = re.search(r'<answer>(.*?)</answer>', response, re.DOTALL)
                if answer_match:
@ -118,10 +205,12 @@ def calculate_rewards(prompts, responses, reward_model, reward_tokenizer):
                    tmp_chat = messages + [{"role": "assistant", "content": answer_content}]
                    answer_score = reward_model.get_score(reward_tokenizer, tmp_chat)
                    answer_score = max(min(answer_score, scale), -scale)
                    # Blended score: whole response 40%, answer content 60%
                    score = score * 0.4 + answer_score * 0.6

            reward_model_scores.append(score)

        # Add the reward-model scores to the total reward
        reward_model_scores = torch.tensor(reward_model_scores, device=args.device)
        rewards += reward_model_scores

@ -129,14 +218,32 @@ def calculate_rewards(prompts, responses, reward_model, reward_tokenizer):


def spo_train_epoch(epoch, loader, iters, ref_model, reward_model, reward_tokenizer, value_tracker, start_step=0, wandb=None):
    """
    Single-epoch loop for SPO training.

    Args:
    - epoch: current training epoch
    - loader: data loader
    - iters: total number of steps in the current epoch
    - ref_model: reference model (a frozen copy of the policy)
    - reward_model: reward model
    - reward_tokenizer: the reward model's tokenizer
    - value_tracker: value tracker used for baselines and their updates
    - start_step: starting step (for resuming)
    - wandb: Weights & Biases logging object (optional)
    """
    for step, batch in enumerate(loader, start=start_step + 1):
        # Get the prompt texts from the batch
        prompts = batch['prompt']  # list[str], length B
        # Tokenize and pad the prompts; left padding keeps the important tail of the sequence
        prompt_inputs = tokenizer(prompts, return_tensors="pt", padding=True, return_token_type_ids=False,
                                  padding_side="left", add_special_tokens=False).to(args.device)  # input_ids: [B, P], attention_mask: [B, P]
        # Truncate the prompt, keeping only the last max_seq_len tokens
        if args.max_seq_len:
            prompt_inputs["input_ids"] = prompt_inputs["input_ids"][:, -args.max_seq_len:]
            prompt_inputs["attention_mask"] = prompt_inputs["attention_mask"][:, -args.max_seq_len:]

        # Generate responses with the policy model (no gradients)
        with torch.no_grad():
            # A DDP-wrapped model exposes generate via .module
            model_for_gen = model.module if isinstance(model, DistributedDataParallel) else model
@ -144,53 +251,82 @@ def spo_train_epoch(epoch, loader, iters, ref_model, reward_model, reward_tokeni
                **prompt_inputs, max_new_tokens=args.max_gen_len, do_sample=True, temperature=0.8,
                num_return_sequences=1, pad_token_id=tokenizer.pad_token_id)  # [B, P+R]

        # Keep only the generated continuation (strip the prompt part)
        completion_ids = outputs[:, prompt_inputs["input_ids"].size(1):]  # [B, R]

        def get_per_token_logps(mdl, input_ids, n_keep):
            """
            Compute per-token log-probabilities.

            Args:
            - mdl: model object
            - input_ids: input token-ID tensor
            - n_keep: number of trailing tokens to compute log-probabilities for

            Returns:
            - Tensor: log-probabilities of shape (batch_size, n_keep)
            """
            # Detach inference-mode tensors so they can participate in the graph
            input_ids = input_ids.detach().clone() if input_ids.is_inference() else input_ids
            # Forward pass for the logits, dropping the last position (it has no next-token label)
            logits = mdl(input_ids, logits_to_keep=n_keep + 1).logits[:, :-1, :]
            per_token_logps = []
            # For each sample, gather the log-probability of every realized token
            for logits_row, ids_row in zip(logits, input_ids[:, -n_keep:]):
                ids_row = ids_row.detach().clone() if ids_row.is_inference() else ids_row
                # gather picks the log-probability at each token's ID
                per_token_logps.append(torch.gather(logits_row.log_softmax(dim=-1), 1, ids_row.unsqueeze(1)).squeeze(1))
            return torch.stack(per_token_logps)
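
        # Minimal sketch of the gather step (toy numbers, illustration only):
        #   >>> logits_row = torch.tensor([[2.0, 0.0, 0.0], [0.0, 2.0, 0.0]])
        #   >>> ids_row = torch.tensor([0, 1])
        #   >>> torch.gather(logits_row.log_softmax(dim=-1), 1, ids_row.unsqueeze(1)).squeeze(1)
        #   tensor([-0.2395, -0.2395])  # log-prob of the realized token at each position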

        # Log-probabilities under the policy, plus the MoE auxiliary loss (if any)
        with autocast_ctx:
            per_token_logps = get_per_token_logps(model, outputs, completion_ids.size(1))  # [B, R]
            res = model(outputs) if lm_config.use_moe else None
            aux_loss = res.aux_loss if res is not None else torch.tensor(0.0, device=args.device)

        # Log-probabilities under the reference model (no gradients)
        with torch.no_grad():
            ref_per_token_logps = get_per_token_logps(ref_model, outputs, completion_ids.size(1))  # [B, R]

        # Decode the token IDs back into text responses
        completions = tokenizer.batch_decode(completion_ids, skip_special_tokens=True)  # list[str], length B
        # Compute the rewards
        rewards = calculate_rewards(prompts, completions, reward_model, reward_tokenizer).to(args.device)  # [B]

        # Fetch the baselines (used for advantage estimation)
        baselines = value_tracker.get_baselines(len(prompts)).to(args.device)  # [B]

        scale = 3.0
        # Un-normalize the baselines back to the raw reward range [-3, 3]
        unnormalized_baselines = baselines * (2 * scale) - scale  # [B]
        # Advantage = reward - baseline
        advantages = rewards - unnormalized_baselines  # [B]

        # Use the baseline-derived advantages directly, clipping only to avoid exploding
        # gradients; no in-batch normalization, since the baseline already provides a
        # stable reference across batches
        advantages = advantages.clamp(-5.0, 5.0)

        # Build the completion mask, marking the valid positions of each response
        is_eos = completion_ids == tokenizer.eos_token_id  # [B, R]
        eos_idx = torch.full((is_eos.size(0),), is_eos.size(1), dtype=torch.long, device=args.device)  # [B]
        eos_idx[is_eos.any(dim=1)] = is_eos.int().argmax(dim=1)[is_eos.any(dim=1)]
        completion_mask = (torch.arange(is_eos.size(1), device=args.device).expand(is_eos.size(0), -1) <= eos_idx.unsqueeze(1)).int()  # [B, R]
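
        # Sketch of the mask construction (toy IDs, illustration only): with eos_token_id = 2
        # and completion_ids = [[7, 2, 9]], the EOS sits at index 1, so completion_mask
        # becomes [[1, 1, 0]]: tokens up to and including EOS count toward the loss, while
        # anything generated after EOS is ignored.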

        # KL divergence between the policy and the reference model
        kl_div = ref_per_token_logps - per_token_logps  # [B, R]
        per_token_kl = torch.exp(kl_div) - kl_div - 1  # [B, R], non-negative estimator of the reverse KL
        # Per-token loss: policy gradient plus the KL penalty
        per_token_loss = -per_token_logps * advantages.unsqueeze(1) + args.beta * per_token_kl  # [B, R]
        # Average over the response length, then over the batch, to get the policy loss
        policy_loss = ((per_token_loss * completion_mask).sum(dim=1) / completion_mask.sum(dim=1)).mean()
        # Total loss = policy loss + auxiliary loss (MoE load balancing), divided for gradient accumulation
        loss = (policy_loss + aux_loss) / args.accumulation_steps  # scalar
        loss.backward()

        # Update the value tracker and fetch the current rho
        response_masks = completion_mask.float()  # [B, R]
        rho = value_tracker.update(rewards, per_token_logps.detach(), response_masks)

        # Gradient accumulation: update the parameters every accumulation_steps steps
        if (step + 1) % args.accumulation_steps == 0:
            if args.grad_clip > 0:
                torch.nn.utils.clip_grad_norm_(model.parameters(), args.grad_clip)
@ -198,6 +334,7 @@ def spo_train_epoch(epoch, loader, iters, ref_model, reward_model, reward_tokeni
            scheduler.step()
            optimizer.zero_grad()

        # Print the log and record it to wandb
        if step % args.log_interval == 0 or step == iters:
            policy_loss_val = loss.item() * args.accumulation_steps
            current_aux_loss = aux_loss.item()
@ -224,19 +361,23 @@ def spo_train_epoch(epoch, loader, iters, ref_model, reward_model, reward_tokeni
                "learning_rate": current_lr
            })

        # Periodically save model checkpoints
        if (step % args.save_interval == 0 or step == iters - 1) and is_main_process():
            model.eval()
            moe_suffix = '_moe' if lm_config.use_moe else ''
            ckp = f'{args.save_dir}/{args.save_weight}_{lm_config.hidden_size}{moe_suffix}.pth'
            raw_model = model.module if isinstance(model, DistributedDataParallel) else model
            raw_model = getattr(raw_model, '_orig_mod', raw_model)
            # Grab the state dict, convert it to half precision, and store it on the CPU
            state_dict = raw_model.state_dict()
            torch.save({k: v.half().cpu() for k, v in state_dict.items()}, ckp)
            # Also save a full checkpoint (optimizer state etc.)
            lm_checkpoint(lm_config, weight=args.save_weight, model=model, optimizer=optimizer,
                          epoch=epoch, step=step, wandb=wandb, save_dir='../checkpoints', scheduler=scheduler)
            model.train()
            del state_dict

        # Drop this iteration's intermediates to free memory
        del prompt_inputs, outputs, completion_ids, per_token_logps, ref_per_token_logps
        del completions, rewards, advantages, completion_mask, baselines, response_masks

@ -257,36 +398,43 @@ if __name__ == "__main__":
    parser.add_argument("--save_interval", type=int, default=10, help="model save interval")
    parser.add_argument('--hidden_size', default=512, type=int, help="hidden dimension")
    parser.add_argument('--num_hidden_layers', default=8, type=int, help="number of hidden layers")
    parser.add_argument('--use_moe', default=0, type=int, choices=[0, 1], help="use the MoE architecture (0=no, 1=yes)")
    parser.add_argument('--max_seq_len', default=66, type=int, help="maximum prompt length")
    parser.add_argument("--max_gen_len", type=int, default=1536, help="maximum generation length")
    parser.add_argument("--data_path", type=str, default="../dataset/rlaif-mini.jsonl", help="RLAIF data path")
    parser.add_argument("--beta", type=float, default=0.02, help="KL penalty coefficient")
    parser.add_argument("--reasoning", type=int, default=1, choices=[0, 1], help='model type (0=plain model, 1=reasoning model)')
    parser.add_argument("--reward_model_path", type=str, default="../../internlm2-1_8b-reward", help="reward model path")
    parser.add_argument('--from_resume', default=0, type=int, choices=[0, 1], help="auto-detect & resume training (0=no, 1=yes)")
    parser.add_argument("--use_wandb", action="store_true", help="use wandb")
    parser.add_argument("--wandb_project", type=str, default="MiniMind-SPO", help="wandb project name")
    parser.add_argument("--use_compile", default=0, type=int, choices=[0, 1], help="use torch.compile acceleration (0=no, 1=yes)")
    args = parser.parse_args()

    # ========== 1. Initialize the environment and random seeds ==========
    local_rank = init_distributed_mode()
    if dist.is_initialized(): args.device = f"cuda:{local_rank}"
    setup_seed(42 + (dist.get_rank() if dist.is_initialized() else 0))

    # ========== 2. Configure directories and model parameters, check for checkpoints ==========
    os.makedirs(args.save_dir, exist_ok=True)
    lm_config = MiniMindConfig(
        hidden_size=args.hidden_size,
        num_hidden_layers=args.num_hidden_layers,
        max_seq_len=args.max_seq_len + args.max_gen_len,
        use_moe=bool(args.use_moe)
    )
    # When resume is enabled, look for an existing checkpoint
    ckp_data = lm_checkpoint(lm_config, weight=args.save_weight, save_dir='../checkpoints') if args.from_resume == 1 else None

    # ========== 3. Set up mixed precision ==========
    device_type = "cuda" if "cuda" in args.device else "cpu"
    dtype = torch.bfloat16 if args.dtype == "bfloat16" else torch.float16
    # https://docs.pytorch.ac.cn/docs/stable/amp
    # Deprecated in newer releases: autocast_ctx = nullcontext() if device_type == "cpu" else torch.cuda.amp.autocast(dtype=dtype)
    autocast_ctx = nullcontext() if device_type == "cpu" else torch.amp.autocast(device_type=device_type, dtype=dtype)

    # ========== 4. Configure wandb ==========
    wandb = None
    if args.use_wandb and is_main_process():
        import swanlab as wandb
@ -295,35 +443,37 @@ if __name__ == "__main__":
        wandb_run_name = f"MiniMind-SPO-Epoch-{args.epochs}-BS-{args.batch_size}-LR-{args.learning_rate}"
        wandb.init(project=args.wandb_project, name=wandb_run_name, id=wandb_id, resume=resume)

    # ========== 5. Initialize the models (Policy, Ref, Reward), the Value Tracker, and the data ==========
    base_weight = "reason" if args.reasoning == 1 else "full_sft"
    # Policy model (the model being trained)
    model, tokenizer = init_model(lm_config, base_weight, device=args.device)
    if args.use_compile == 1:
        model = torch.compile(model)
        Logger('torch.compile enabled')
    # Reference model (frozen, used for the KL divergence)
    ref_model, _ = init_model(lm_config, base_weight, device=args.device)
    ref_model = ref_model.eval().requires_grad_(False)
    # Reward model (frozen, used to score response quality)
    reward_model = AutoModel.from_pretrained(
        args.reward_model_path, torch_dtype=torch.float16, trust_remote_code=True
    )
    reward_model = reward_model.to(args.device).eval().requires_grad_(False)
    reward_tokenizer = AutoTokenizer.from_pretrained(args.reward_model_path, trust_remote_code=True)
    # Value Tracker (adaptive baseline estimator)
    value_tracker = AutoAdaptiveValueTracker(rho_mode='kl', rho_const=0.9, D_half=0.06, clip_lower=0.5, clip_upper=0.96)

    # Load the RLAIF dataset
    train_ds = RLAIFDataset(args.data_path, tokenizer, max_length=lm_config.max_seq_len)
    train_sampler = DistributedSampler(train_ds) if dist.is_initialized() else None
    optimizer = optim.AdamW(model.parameters(), lr=args.learning_rate)

    # Count the training steps and build the learning-rate scheduler
    loader_for_count = DataLoader(train_ds, batch_size=args.batch_size, sampler=train_sampler)
    iters = len(loader_for_count)
    total_optimizer_steps = (iters // args.accumulation_steps) * args.epochs
    scheduler = CosineAnnealingLR(optimizer, T_max=total_optimizer_steps, eta_min=args.learning_rate / 10)

    # ========== 6. Restore state from a checkpoint ==========
    start_epoch, start_step = 0, 0
    if ckp_data:
        model.load_state_dict(ckp_data['model'])
@ -332,15 +482,19 @@ if __name__ == "__main__":
        start_epoch = ckp_data['epoch']
        start_step = ckp_data.get('step', 0)

    # ========== 7. Wrap the model with DDP ==========
    if dist.is_initialized():
        # Skip DDP synchronization for the positional-encoding buffers
        model._ddp_params_and_buffers_to_ignore = {"freqs_cos", "freqs_sin"}
        model = DistributedDataParallel(model, device_ids=[local_rank])

    # ========== 8. Start training ==========
    for epoch in range(start_epoch, args.epochs):
        # Set the epoch on the distributed sampler so ranks see different data each epoch
        train_sampler and train_sampler.set_epoch(epoch)
        # Seed each epoch differently so the data order changes
        setup_seed(42 + epoch); indices = torch.randperm(len(train_ds)).tolist()
        # On the first resumed epoch, skip the steps that were already trained
        skip = start_step if (epoch == start_epoch and start_step > 0) else 0
        batch_sampler = SkipBatchSampler(train_sampler or indices, args.batch_size, skip)
        loader = DataLoader(train_ds, batch_sampler=batch_sampler, num_workers=args.num_workers, pin_memory=True)
@ -350,5 +504,5 @@ if __name__ == "__main__":
        else:
            spo_train_epoch(epoch, loader, len(loader), ref_model, reward_model, reward_tokenizer, value_tracker, 0, wandb)

    # ========== 9. Clean up distributed processes ==========
    if dist.is_initialized(): dist.destroy_process_group()
@ -1,40 +1,88 @@
# Note: Re-training the tokenizer (the "vocabulary") is not recommended; MiniMind already ships with one, and this script is for learning and reference only.
# Models trained on different vocabularies produce completely incompatible outputs, reducing model reusability in the community.
import os
import json
from tokenizers import decoders, models, pre_tokenizers, trainers, Tokenizer

# Training-data file path
DATA_PATH = '../dataset/pretrain_hq.jsonl'
# Directory where the trained tokenizer is saved
TOKENIZER_DIR = '../model_learn_tokenizer/'
# Vocabulary size, controls how many tokens the tokenizer learns
VOCAB_SIZE = 6400


def get_texts(data_path):
    """
    Read text samples line by line from a JSONL file.

    A generator that walks the data file and yields the text content.
    At most the first 10000 lines are read, for experimental training.

    Args:
    - data_path: path to the JSONL data file

    Yields:
    - str: the 'text' field of each line
    """
    with open(data_path, 'r', encoding='utf-8') as f:
        for i, line in enumerate(f):
            # Experimental: only the first 10000 lines are used for testing
            if i >= 10000:
                break
            data = json.loads(line)
            yield data['text']


def train_tokenizer(data_path, tokenizer_dir, vocab_size):
    """
    Train a BPE (Byte Pair Encoding) tokenizer.

    Trains a custom BPE tokenizer with the Hugging Face tokenizers library,
    covering the pre-tokenizer setup, the trainer, and saving the vocabulary
    and configuration files.

    Args:
    - data_path: training-data file path
    - tokenizer_dir: directory to save the tokenizer
    - vocab_size: target vocabulary size
    """
    # Create a tokenizer instance backed by a BPE model
    tokenizer = Tokenizer(models.BPE())
    # Use the ByteLevel pre-tokenizer without a prefix space
    tokenizer.pre_tokenizer = pre_tokenizers.ByteLevel(add_prefix_space=False)

    # Configure the BPE trainer
    trainer = trainers.BpeTrainer(
        vocab_size=vocab_size,
        # Special tokens: end-of-text, instruction start, instruction end
        special_tokens=["<|endoftext|>", "<|im_start|>", "<|im_end|>"],
        show_progress=True,
        # Start from the ByteLevel pre-tokenizer's initial alphabet
        initial_alphabet=pre_tokenizers.ByteLevel.alphabet()
    )

    # Stream the training texts
    texts = get_texts(data_path)
    # Train the tokenizer from the iterator
    tokenizer.train_from_iterator(texts, trainer=trainer)
    # Use the ByteLevel decoder
    tokenizer.decoder = decoders.ByteLevel()

    # Verify that the special-token IDs match the expected values
    assert tokenizer.token_to_id("<|endoftext|>") == 0
    assert tokenizer.token_to_id("<|im_start|>") == 1
    assert tokenizer.token_to_id("<|im_end|>") == 2

    # Create the save directory (if it does not exist)
    os.makedirs(tokenizer_dir, exist_ok=True)
    # Save the full tokenizer configuration
    tokenizer.save(os.path.join(tokenizer_dir, "tokenizer.json"))
    # Save the BPE model files (vocab.json and merges.txt)
    tokenizer.model.save(tokenizer_dir)

    # Build the Hugging Face-style tokenizer configuration
    config = {
        "add_bos_token": False,
        "add_eos_token": False,
@ -79,48 +127,68 @@ def train_tokenizer(data_path, tokenizer_dir, vocab_size):
"chat_template": "{%- if tools %}\n {{- '<|im_start|>system\\n' }}\n {%- if messages[0].role == 'system' %}\n {{- messages[0].content + '\\n\\n' }}\n {%- endif %}\n {{- \"# Tools\\n\\nYou may call one or more functions to assist with the user query.\\n\\nYou are provided with function signatures within <tools></tools> XML tags:\\n<tools>\" }}\n {%- for tool in tools %}\n {{- \"\\n\" }}\n {{- tool | tojson }}\n {%- endfor %}\n {{- \"\\n</tools>\\n\\nFor each function call, return a json object with function name and arguments within <tool_call></tool_call> XML tags:\\n<tool_call>\\n{\\\"name\\\": <function-name>, \\\"arguments\\\": <args-json-object>}\\n</tool_call><|im_end|>\\n\" }}\n{%- else %}\n {%- if messages[0]['role'] == 'system' -%}\n {{- '<|im_start|>system\\n' + messages[0]['content'] + '<|im_end|>\\n' }}\n {%- else -%}\n {{- '<|im_start|>system\\nYou are a helpful assistant<|im_end|>\\n' }}\n {%- endif %}\n{%- endif %}\n{%- set ns = namespace(multi_step_tool=true, last_query_index=messages|length - 1) %}\n{%- for message in messages[::-1] %}\n {%- set index = (messages|length - 1) - loop.index0 %}\n {%- if ns.multi_step_tool and message.role == \"user\" and message.content is string and not(message.content.startswith('<tool_response>') and message.content.endswith('</tool_response>')) %}\n {%- set ns.multi_step_tool = false %}\n {%- set ns.last_query_index = index %}\n {%- endif %}\n{%- endfor %}\n{%- for message in messages %}\n {%- if message.content is string %}\n {%- set content = message.content %}\n {%- else %}\n {%- set content = '' %}\n {%- endif %}\n {%- if (message.role == \"user\") or (message.role == \"system\" and not loop.first) %}\n {{- '<|im_start|>' + message.role + '\\n' + content + '<|im_end|>' + '\\n' }}\n {%- elif message.role == \"assistant\" %}\n {{- '<|im_start|>' + message.role + '\\n' + content }}\n {%- if message.tool_calls %}\n {%- for tool_call in message.tool_calls %}\n {%- if (loop.first and content) or (not loop.first) %}\n {{- '\\n' }}\n {%- endif %}\n {%- if tool_call.function %}\n {%- set tool_call = tool_call.function %}\n {%- endif %}\n {{- '<tool_call>\\n{\"name\": \"' }}\n {{- tool_call.name }}\n {{- '\", \"arguments\": ' }}\n {%- if tool_call.arguments is string %}\n {{- tool_call.arguments }}\n {%- else %}\n {{- tool_call.arguments | tojson }}\n {%- endif %}\n {{- '}\\n</tool_call>' }}\n {%- endfor %}\n {%- endif %}\n {{- '<|im_end|>\\n' }}\n {%- elif message.role == \"tool\" %}\n {%- if loop.first or (messages[loop.index0 - 1].role != \"tool\") %}\n {{- '<|im_start|>user' }}\n {%- endif %}\n {{- '\\n<tool_response>\\n' }}\n {{- content }}\n {{- '\\n</tool_response>' }}\n {%- if loop.last or (messages[loop.index0 + 1].role != \"tool\") %}\n {{- '<|im_end|>\\n' }}\n {%- endif %}\n {%- endif %}\n{%- endfor %}\n{%- if add_generation_prompt %}\n {{- '<|im_start|>assistant\\n' }}\n {%- if enable_thinking is defined and enable_thinking is false %}\n {{- '<think>\\n\\n</think>\\n\\n' }}\n {%- endif %}\n{%- endif %}"
    }

    # Save the tokenizer configuration file
    with open(os.path.join(tokenizer_dir, "tokenizer_config.json"), "w", encoding="utf-8") as f:
        json.dump(config, f, ensure_ascii=False, indent=4)
    print("Tokenizer training completed.")


def eval_tokenizer(tokenizer_dir):
    """
    Evaluate and smoke-test the trained tokenizer.

    Loads the trained tokenizer, builds a prompt with the chat template,
    and checks encode/decode consistency plus streaming decoding.

    Args:
    - tokenizer_dir: path to the tokenizer directory
    """
    from transformers import AutoTokenizer

    # Load the trained tokenizer
    tokenizer = AutoTokenizer.from_pretrained(tokenizer_dir)

    # Build the test conversation (the Chinese text exercises multi-byte tokens)
    messages = [
        {"role": "system", "content": "你是一个优秀的聊天机器人, 总是给我正确的回应!"},
        {"role": "user", "content": '你来自哪里?'},
        {"role": "assistant", "content": '我来自地球'}
    ]

    # Apply the chat template to build the prompt
    new_prompt = tokenizer.apply_chat_template(
        messages,
        tokenize=False
    )
    print('-' * 100)
    print(new_prompt)

    print('-' * 100)
    # Print the vocabulary size
    print('tokenizer vocabulary size:', len(tokenizer))
    # Encode the prompt
    model_inputs = tokenizer(new_prompt)
    print('encoded length:', len(model_inputs['input_ids']))
    # Decode and verify round-trip consistency
    response = tokenizer.decode(model_inputs['input_ids'], skip_special_tokens=False)
    print('decoder consistency:', response == new_prompt, "\n")


    print('-' * 100)
    print('streaming decode (byte buffering) test:')
    input_ids = model_inputs['input_ids']
    token_cache = []
    for tid in input_ids:
        token_cache.append(tid)
        # Try to decode the currently buffered tokens
        current_decode = tokenizer.decode(token_cache)
        # Emit only when the decode is valid and contains no Unicode replacement character
        if current_decode and '\ufffd' not in current_decode:
            display_ids = token_cache[0] if len(token_cache) == 1 else token_cache
            raw_tokens = [tokenizer.convert_ids_to_tokens(int(t)) for t in (token_cache if isinstance(token_cache, list) else [token_cache])]
            print(f'Token ID: {str(display_ids):15} -> Raw: {str(raw_tokens):20} -> Decode Str: {current_decode}')
            token_cache = []
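
    # Why buffering is needed: a multi-byte character such as '地' may be split across
    # several byte-level BPE tokens; each partial decode then yields the replacement
    # character U+FFFD, so the buffer keeps accumulating token IDs until the bytes form
    # valid UTF-8 and only then flushes a complete string.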


if __name__ == '__main__':
    train_tokenizer(DATA_PATH, TOKENIZER_DIR, VOCAB_SIZE)
    eval_tokenizer(TOKENIZER_DIR)

@ -1,10 +1,11 @@
"""
Collection of training utility functions.
"""
import os
import sys

# Set the package name
__package__ = "trainer"
# Add the project root to the system path so sibling modules can be imported
sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), '..')))

import random
import math
import numpy as np
@ -16,64 +17,194 @@ from transformers import AutoTokenizer
from model.model_minimind import MiniMindForCausalLM

def get_model_params(model, config):
    """
    Compute and print model-parameter statistics.

    Args:
    - model: the PyTorch model to inspect
    - config: model config object with attributes such as use_moe, n_routed_experts, num_experts, num_experts_per_tok, n_shared_experts

    Returns:
    - None; prints the parameter counts on the main process
    """
    # Total parameter count, in millions (M)
    total = sum(p.numel() for p in model.parameters()) / 1e6

    # MoE expert counts from the config
    # n_routed: total routed experts, n_active: experts activated per token, n_shared: shared experts
    n_routed = getattr(config, 'n_routed_experts', getattr(config, 'num_experts', 0))
    n_active = getattr(config, 'num_experts_per_tok', 0)
    n_shared = getattr(config, 'n_shared_experts', 0)

    # Parameter count of one routed expert, in millions (M),
    # found by matching parameter names containing 'mlp.experts.0.'
    expert = sum(p.numel() for n, p in model.named_parameters() if 'mlp.experts.0.' in n) / 1e6

    # Parameter count of one shared expert, in millions (M),
    # found by matching parameter names containing 'mlp.shared_experts.0.'
    shared_expert = sum(p.numel() for n, p in model.named_parameters() if 'mlp.shared_experts.0.' in n) / 1e6

    # Base parameters (everything outside the experts)
    base = total - (expert * n_routed) - (shared_expert * n_shared)

    # Activated parameters (base + activated experts)
    active = base + (expert * n_active) + (shared_expert * n_shared)

    # If some expert parameters stay inactive, print both totals; otherwise just the total
    if active < total:
        Logger(f'Model Params: {total:.2f}M-A{active:.2f}M')
    else:
        Logger(f'Model Params: {total:.2f}M')
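

# Worked example of the statistic above (assumed counts, illustration only):
# with base = 20M, one routed expert = 2M, n_routed = 8, n_active = 2,
# one shared expert = 1M, n_shared = 1:
#   total  = 20 + 2 * 8 + 1 * 1 = 37M
#   active = 20 + 2 * 2 + 1 * 1 = 25M
# so the function prints "Model Params: 37.00M-A25.00M".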


def is_main_process():
    """
    Check whether the current process is the main process.

    In a distributed run, only the rank-0 process counts as the main process.
    When the distributed environment is not initialized, the current process is also considered main.

    Returns:
    - bool: True for the main process, False otherwise
    """
    return not dist.is_initialized() or dist.get_rank() == 0


def Logger(content):
    """
    Rank-aware logging helper.

    Prints only on the main process (rank 0), avoiding duplicated logs in distributed training.

    Args:
    - content: the content to print, a string or any printable object
    """
    if is_main_process():
        print(content)


def get_lr(current_step, total_steps, lr):
    """
    Compute the current learning rate with cosine annealing.

    The schedule starts at lr, follows a cosine curve down to 0.1 * lr at the
    final step, and passes through 0.55 * lr at the halfway point.
    Formula: lr * (0.1 + 0.45 * (1 + cos(pi * current_step / total_steps)))

    Args:
    - current_step: current training step (starting at 0)
    - total_steps: total number of training steps
    - lr: base learning rate

    Returns:
    - float: the learning rate for the current step
    """
    return lr * (0.1 + 0.45 * (1 + math.cos(math.pi * current_step / total_steps)))
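

# Quick check of the schedule endpoints (illustration only):
#   get_lr(0, 100, 1e-3)   = 1e-3 * (0.1 + 0.45 * 2) = 1.0e-3   # start: full lr
#   get_lr(50, 100, 1e-3)  = 1e-3 * (0.1 + 0.45 * 1) = 5.5e-4   # midpoint
#   get_lr(100, 100, 1e-3) = 1e-3 * (0.1 + 0.45 * 0) = 1.0e-4   # end: 10% of lr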


def init_distributed_mode():
    """
    Initialize distributed training mode.

    The RANK environment variable decides whether distributed training is in use.
    Without it, 0 is returned to signal non-DDP mode.
    With it, the NCCL backend is initialized and the local GPU device is selected.

    Returns:
    - int: the local rank; 0 in non-DDP mode
    """
    # Decide whether distributed training needs initializing:
    # a missing RANK (or RANK == -1) means no distributed training
    if int(os.environ.get("RANK", -1)) == -1:
        return 0  # non-DDP mode

    # Initialize the process group with the NCCL backend (suited to GPU communication)
    dist.init_process_group(backend="nccl")

    # Read the local rank from the environment and bind this process to its GPU
    local_rank = int(os.environ["LOCAL_RANK"])
    torch.cuda.set_device(local_rank)

    return local_rank


def setup_seed(seed: int):
    """
    Set all random seeds so experiments are reproducible.
    - Seeds Python's random, numpy, and torch
    - Also configures the CUDA seeds and determinism settings

    Args:
    - seed: the random seed, an integer
    """
    # Seed Python's built-in random module
    random.seed(seed)
    # Seed NumPy
    np.random.seed(seed)
    # Seed PyTorch on the CPU
    torch.manual_seed(seed)
    # Seed the current GPU
    torch.cuda.manual_seed(seed)
    # Seed all GPUs (needed for multi-GPU training)
    torch.cuda.manual_seed_all(seed)
    # Force deterministic cuDNN kernels; this can slow training down,
    # but it keeps results reproducible
    torch.backends.cudnn.deterministic = True
    # Disable cuDNN autotuning, which could otherwise pick different algorithms per run
    torch.backends.cudnn.benchmark = False


def lm_checkpoint(lm_config, weight='full_sft', model=None, optimizer=None, epoch=0, step=0, wandb=None, save_dir='../checkpoints', **kwargs):
    """
    Save or load a model checkpoint.

    Two modes are supported:
    1. Save mode: when model is not None, saves the model weights, optimizer state, and training progress
    2. Load mode: when model is None, restores the training state from the checkpoint file

    Args:
    - lm_config: model config object with attributes such as use_moe and hidden_size
    - weight: checkpoint name prefix, defaults to 'full_sft'
    - model: the model to save; None selects load mode
    - optimizer: optimizer object, used to save the optimizer state
    - epoch: current training epoch
    - step: current training step
    - wandb: Weights & Biases logging object, used to fetch the run ID
    - save_dir: checkpoint directory
    - **kwargs: extra saveable objects (such as an lr scheduler); anything with a state_dict method is saved via state_dict

    Returns:
    - dict or None: the checkpoint data in load mode, None in save mode
    """
    # Make sure the save directory exists
    os.makedirs(save_dir, exist_ok=True)

    # Build the checkpoint file names;
    # MoE models get a dedicated suffix
    moe_path = '_moe' if lm_config.use_moe else ''

    # Path of the plain model-weights file
    ckp_path = f'{save_dir}/{weight}_{lm_config.hidden_size}{moe_path}.pth'

    # Path of the full resume checkpoint (optimizer state and more)
    resume_path = f'{save_dir}/{weight}_{lm_config.hidden_size}{moe_path}_resume.pth'

    # ===== Save mode =====
    if model is not None:
        # Unwrap the DDP wrapper if present
        raw_model = model.module if isinstance(model, DistributedDataParallel) else model
        # Unwrap the torch.compile wrapper via its _orig_mod attribute
        raw_model = getattr(raw_model, '_orig_mod', raw_model)

        # Grab the state dict and move the weights to the CPU in half precision (FP16)
        state_dict = raw_model.state_dict()
        state_dict = {k: v.half().cpu() for k, v in state_dict.items()}

        # Write through a temp file so an interrupted save cannot corrupt the checkpoint
        ckp_tmp = ckp_path + '.tmp'
        torch.save(state_dict, ckp_tmp)
        os.replace(ckp_tmp, ckp_path)

        # Fetch the Weights & Biases run ID
        wandb_id = None
        if wandb:
            if hasattr(wandb, 'get_run'):
@ -82,16 +213,20 @@ def lm_checkpoint(lm_config, weight='full_sft', model=None, optimizer=None, epoc
            else:
                wandb_id = getattr(wandb, 'id', None)

        # Assemble everything needed to resume training
        resume_data = {
            'model': state_dict,  # model weights
            'optimizer': optimizer.state_dict(),  # optimizer state
            'epoch': epoch,  # current training epoch
            'step': step,  # current training step
            'world_size': dist.get_world_size() if dist.is_initialized() else 1,  # distributed world size
            'wandb_id': wandb_id  # W&B run ID
        }

        # Fold in the extra saveable objects (such as an lr scheduler)
        for key, value in kwargs.items():
            if value is not None:
                # If the object is a (possibly DDP-wrapped) module, unwrap it before taking the state_dict
                if hasattr(value, 'state_dict'):
                    raw_value = value.module if isinstance(value, DistributedDataParallel) else value
                    raw_value = getattr(raw_value, '_orig_mod', raw_value)
@ -99,59 +234,142 @@ def lm_checkpoint(lm_config, weight='full_sft', model=None, optimizer=None, epoc
                else:
                    resume_data[key] = value

        # Save the full checkpoint to a temp file, then replace atomically
        resume_tmp = resume_path + '.tmp'
        torch.save(resume_data, resume_tmp)
        os.replace(resume_tmp, resume_path)

        # Free memory
        del state_dict, resume_data
        torch.cuda.empty_cache()

    # ===== Load mode =====
    else:
        # If a checkpoint file exists, load the resume data
        if os.path.exists(resume_path):
            ckp_data = torch.load(resume_path, map_location='cpu')

            # Handle a changed GPU count by rescaling the step counter
            saved_ws = ckp_data.get('world_size', 1)
            current_ws = dist.get_world_size() if dist.is_initialized() else 1

            if saved_ws != current_ws:
                # Rescale the step proportionally when the GPU count changes
                ckp_data['step'] = ckp_data['step'] * saved_ws // current_ws
                Logger(f'GPU count changed ({saved_ws}→{current_ws}); step automatically rescaled to {ckp_data["step"]}')

            return ckp_data

        # No checkpoint file found
        return None
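

# Usage sketch (assumed caller objects, illustration only):
#   save: lm_checkpoint(cfg, weight='full_sft', model=model, optimizer=opt, epoch=e, step=s)
#   load: ckp = lm_checkpoint(cfg, weight='full_sft')  # model=None selects load mode
#         if ckp: model.load_state_dict(ckp['model']); opt.load_state_dict(ckp['optimizer'])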


def init_model(lm_config, from_weight='pretrain', tokenizer_path='../model', save_dir='../out', device='cuda'):
    """
    Initialize the model and tokenizer.

    Loads a Hugging Face-format tokenizer and creates a MiniMind model instance,
    optionally loading pretrained weights.

    Args:
    - lm_config: model config object
    - from_weight: weight-file name prefix, defaults to 'pretrain'; 'none' skips weight loading
    - tokenizer_path: path to the tokenizer directory
    - save_dir: directory that holds the weight files
    - device: device to run the model on, defaults to 'cuda'

    Returns:
    - tuple: (model, tokenizer)
    """
    # Load the tokenizer from the pretrained directory
    tokenizer = AutoTokenizer.from_pretrained(tokenizer_path)

    # Create the MiniMind causal language model
    model = MiniMindForCausalLM(lm_config)

    # When a weight prefix is given, load the model parameters from the checkpoint file
    if from_weight != 'none':
        # The weight-file name depends on whether MoE is used
        moe_suffix = '_moe' if lm_config.use_moe else ''
        weight_path = f'{save_dir}/{from_weight}_{lm_config.hidden_size}{moe_suffix}.pth'

        # Load the weights onto the target device and into the model
        weights = torch.load(weight_path, map_location=device)
        model.load_state_dict(weights, strict=False)

    # Print the parameter statistics
    get_model_params(model, lm_config)

    # Print the trainable (gradient-requiring) parameter count
    Logger(f'Trainable Params: {sum(p.numel() for p in model.parameters() if p.requires_grad) / 1e6:.3f}M')

    # Move the model to the target device and return
    return model.to(device), tokenizer


class SkipBatchSampler(Sampler):
    """
    A sampler that can skip a given number of batches.

    Extends the PyTorch Sampler to skip some initial batches during training,
    typically to pass over batches already seen when resuming from a checkpoint.
    """

    def __init__(self, sampler, batch_size, skip_batches=0):
        """
        Initialize the SkipBatchSampler.

        Args:
        - sampler: base sampler providing the index sequence
        - batch_size: number of samples per batch
        - skip_batches: number of batches to skip, defaults to 0 (skip nothing)
        """
        self.sampler = sampler
        self.batch_size = batch_size
        self.skip_batches = skip_batches

    def __iter__(self):
        """
        Yield batch indices.

        Walks the base sampler's indices, grouping them into batches and
        skipping the requested number of initial batches.

        Yields:
        - list: a batch of sample indices, batch_size entries each
        """
        batch = []
        skipped = 0

        # Walk every index from the base sampler
        for idx in self.sampler:
            batch.append(idx)

            # Once the batch is full
            if len(batch) == self.batch_size:
                # Skip it while there are still batches to skip
                if skipped < self.skip_batches:
                    skipped += 1
                    batch = []
                    continue

                # Emit the batch
                yield batch
                batch = []

        # Emit the final partial batch, if any
        if len(batch) > 0 and skipped >= self.skip_batches:
            yield batch

    def __len__(self):
        """
        Return the number of usable batches.

        Returns:
        - int: batch count after skipping skip_batches
        """
        # Total number of batches, rounded up
        total_batches = (len(self.sampler) + self.batch_size - 1) // self.batch_size

        # Batches remaining after the skip; never negative
        return max(0, total_batches - self.skip_batches)
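

# Usage sketch (toy indices, illustration only) -- with 10 indices, batch_size=3 and
# skip_batches=1, the first full batch [0, 1, 2] is skipped and the partial tail is kept:
#   >>> list(SkipBatchSampler(range(10), batch_size=3, skip_batches=1))
#   [[3, 4, 5], [6, 7, 8], [9]]
# __len__ agrees: ceil(10 / 3) - 1 = 3 batches.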