This commit is contained in:
Bader 2026-02-07 05:40:43 +08:00 committed by GitHub
commit 0305628b3d
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
13 changed files with 2243 additions and 463 deletions

View File

@ -2,52 +2,124 @@ import torch
from torch import optim, nn
# 定义Lora网络结构
class LoRA(nn.Module):
"""
Lora 网络结构
"""
def __init__(self, in_features, out_features, rank):
super().__init__()
self.rank = rank # LoRA的秩rank控制低秩矩阵的大小
self.A = nn.Linear(in_features, rank, bias=False) # 低秩矩阵A
self.B = nn.Linear(rank, out_features, bias=False) # 低秩矩阵B
# 矩阵A高斯初始化
# LoRA 的秩 (rank)
self.rank = rank
# 低秩矩阵 A & B
self.A = nn.Linear(in_features, rank, bias=False)
self.B = nn.Linear(rank, out_features, bias=False)
# 低秩矩阵初始化
# A: 高斯分布
# B: 全零
self.A.weight.data.normal_(mean=0.0, std=0.02)
# 矩阵B全0初始化
self.B.weight.data.zero_()
def forward(self, x):
# 前向传播
return self.B(self.A(x))
def apply_lora(model, rank=8):
"""
对指定的 model 添加 lora
"""
# 遍历模型的所有模块
for name, module in model.named_modules():
# 检查是否为 "线性层" 且 "权重矩阵为方阵"
if isinstance(module, nn.Linear) and module.weight.shape[0] == module.weight.shape[1]:
# 创建 LoRA 层并移动到模型设备上
lora = LoRA(module.weight.shape[0], module.weight.shape[1], rank=rank).to(model.device)
# 针对 model 设定一个属性值: 属性: "lora" / 值: lora
# https://docs.python.org/zh-cn/3.14/library/functions.html#setattr
setattr(module, "lora", lora)
# 保存原始前向传播函数
original_forward = module.forward
# 显式绑定
def forward_with_lora(x, layer1=original_forward, layer2=lora):
# 定义新的前向传播函数, 用于在原始前向传播基础上加入 LoRA 层
# 这里显示绑定了 layer1=original_forward, layer2=lora 这两个参数
def forward_with_lora(x, layer1=original_forward, layer2=lora):
# 返回原始前向传播结果 & 加上 LoRA 层的输出
return layer1(x) + layer2(x)
# 替换模块的前向传播函数
module.forward = forward_with_lora
def load_lora(model, path):
state_dict = torch.load(path, map_location=model.device)
state_dict = {(k[7:] if k.startswith('module.') else k): v for k, v in state_dict.items()}
"""
加载 LoRA 权重文件并应用到模型
:param model: 包含初始化的 LoRA 层的模型对象, 即不包含训练权重, 只包含结构和空的 LoRA 参数
:param path: LoRA 权重文件路径, 包含所有 A/B 矩阵的 state_dict
"""
# 从文件加载 state_dict, 从而获得完整的模型权重
state_dict = torch.load(path, map_location=model.device)
# 处理模型权重的键名, 移除 state_dict 中所有键名前的 "module." 前缀
# 如果使用了 DP/DDP 这样的数据并行, 那么 pytorch 就会在模型的权重字典中添加 "module." 前缀来表示这些
# 如果没有用 DP/DDP 来包装模型, 那么在使用 model.load_state_dict() 时就会报错
# 因此这里移除掉 state_dict 中每层的 "module." 前缀, 因为这里 lora 微调时不会采用多卡
state_dict = {(k[7:] if k.startswith('module.') else k): v for k, v in state_dict.items()}
# #### 代码等效 (推荐字典推导式, 这里仅作理解) ####
# tmp_state_dict = {}
# for k, v in state_dict.items():
# if k.startswith('module.'):
# # 移除前缀 "module."
# k = k[7:]
# tmp_state_dict[k] = v
# state_dict = tmp_state_dict
# ##############################################
# 遍历 model 的所有模块, 并用 state_dict 中的 数值替换 model 中的数值 (针对 lora 层)
for name, module in model.named_modules():
# 检查模块是否有 lora 属性
# https://docs.python.org/zh-cn/3.14/library/functions.html#hasattr
if hasattr(module, 'lora'):
# 提取与当前模块相关的 LoRA 状态
lora_state = {k.replace(f'{name}.lora.', ''): v for k, v in state_dict.items() if f'{name}.lora.' in k}
# 加载 LoRA 层的状态字典
module.lora.load_state_dict(lora_state)
def save_lora(model, path):
"""
仅保存模型中 LoRA 适配器的权重参数, 不保存主干模型
此函数适用于使用 Hugging Face PEFT 库进行 LoRA 微调的场景
:param model: 训练完成的模型对象 (可以是 PeftModel 或原始 nn.Module)
:param path: 保存路径
"""
# 获取原始模型; 若 model 是 PEFT 包装后的 PeftModel
# 则通过 _orig_mod 获取其封装的原始模型; 否则直接使用当前模型
# https://docs.python.org/zh-cn/3.14/library/functions.html#getattr
# getattr(object, name, default):
# 如果对象有这个属性, 就返回它; 没有就返回默认值
raw_model = getattr(model, '_orig_mod', model)
# 在使用 Hugging Face 的 PEFT 库时, PEFT 会 "包装" 原始模型
# 此时的 model 不再是原来的 AutoModelForCausalLM, 而是一个 PeftModel 类的对象
# 它内部有一个属性 _orig_mod, 用来指向原始模型
# 因此, 这行代码的实际作用如下:
# 如果 model 是一个 PEFT 包装过的, 那么就取原始模型
# 如果 model 没有经过 PEFT 包装, 则直接使用当前模型
# 创建一个空的状态字典用于存储模型
state_dict = {}
for name, module in raw_model.named_modules():
# 检查模块是否有 lora 属性
if hasattr(module, 'lora'):
# 处理模块名称,去除 "module." 缀
# 如果使用了 DP/DDP 这样的数据并行, 那么 pytorch 就会在模型的权重字典中添加 "module." 前缀来表示这些
# 如果没有用 DP/DDP 来包装模型, 那么在使用 model.load_state_dict() 时就会报错
clean_name = name[7:] if name.startswith("module.") else name
# 构建 LoRA 层 的 lora_state
lora_state = {f'{clean_name}.lora.{k}': v for k, v in module.lora.state_dict().items()}
# 更新状态字典
state_dict.update(lora_state)
torch.save(state_dict, path)
# 保存状态字典到文件 (只保存 lora 层的权重)
torch.save(state_dict, path)

View File

@ -1,11 +1,13 @@
# 📘📘📘📘📘📘📘📘📘📘📘📘📘📘📘📘📘📘📘📘📘📘📘📘📘📘📘📘📘📘📘📘📘📘📘📘📘📘📘📘📘📘📘📘📘📘📘📘
# MiniMind Config
# 📘📘📘📘📘📘📘📘📘📘📘📘📘📘📘📘📘📘📘📘📘📘📘📘📘📘📘📘📘📘📘📘📘📘📘📘📘📘📘📘📘📘📘📘📘📘📘📘
#########################################
# MiniMind Config
#########################################
from transformers import PretrainedConfig
class MiniMindConfig(PretrainedConfig):
"""
MiniMind Config, 继承自 HuggingFace PretrainedConfig
用于设置和管理模型的各种超参数和结构设置
"""
model_type = "minimind"
def __init__(
@ -15,30 +17,30 @@ class MiniMindConfig(PretrainedConfig):
eos_token_id: int = 2,
hidden_act: str = 'silu',
hidden_size: int = 512,
intermediate_size: int = None,
intermediate_size: int = None, # FFN 中间层大小, 推荐不用设置, 会自动计算
max_position_embeddings: int = 32768,
num_attention_heads: int = 8,
num_attention_heads: int = 8, # 注意力头数, 也是 Query 的头数
num_hidden_layers: int = 8,
num_key_value_heads: int = 2,
num_key_value_heads: int = 2, # Key / Value 的头数, 如果未设定则等于 num_attention_heads
vocab_size: int = 6400,
rms_norm_eps: float = 1e-05,
rope_theta: int = 1000000.0,
inference_rope_scaling: bool = False,
flash_attn: bool = True,
####################################################
# Here are the specific configurations of MOE
# When use_moe is false, the following is invalid
# MOE 相关配置
# 当 use_moe = false 时, 以下配置将无效
####################################################
use_moe: bool = False,
num_experts_per_tok: int = 2,
n_routed_experts: int = 4,
n_shared_experts: int = 1,
scoring_func: str = 'softmax',
aux_loss_alpha: float = 0.01,
seq_aux: bool = True,
norm_topk_prob: bool = True,
num_experts_per_tok: int = 2, # 每个 token 选择的专家数量
n_routed_experts: int = 4, # 总的专家数量
n_shared_experts: int = 1, # 共享专家
scoring_func: str = 'softmax', # 评分函数, 默认 'softmax'
aux_loss_alpha: float = 0.01, # 辅助损失的 alpha 参数
seq_aux: bool = True, # 是否在序列级别上计算辅助损失
norm_topk_prob: bool = True, # 是否标准化 top-k 概率, 推荐启用
**kwargs
):
):
super().__init__(**kwargs)
self.dropout = dropout
self.bos_token_id = bos_token_id
@ -65,28 +67,29 @@ class MiniMindConfig(PretrainedConfig):
} if self.inference_rope_scaling else None
self.flash_attn = flash_attn
####################################################
# Here are the specific configurations of MOE
# When use_moe is false, the following is invalid
# MOE 相关配置
# 当 use_moe = false 时, 以下配置将无效
####################################################
self.use_moe = use_moe
self.num_experts_per_tok = num_experts_per_tok # 每个token选择的专家数量
self.n_routed_experts = n_routed_experts # 总的专家数量
self.n_shared_experts = n_shared_experts # 共享专家
self.scoring_func = scoring_func # 评分函数,默认为'softmax'
self.aux_loss_alpha = aux_loss_alpha # 辅助损失的alpha参数
self.seq_aux = seq_aux # 是否在序列级别上计算辅助损失
self.norm_topk_prob = norm_topk_prob # 是否标准化top-k概率
self.num_experts_per_tok = num_experts_per_tok # 每个 token 选择的专家数量
self.n_routed_experts = n_routed_experts # 总的专家数量
self.n_shared_experts = n_shared_experts # 共享专家
self.scoring_func = scoring_func # 评分函数, 默认 'softmax'
self.aux_loss_alpha = aux_loss_alpha # 辅助损失的 alpha 参数
self.seq_aux = seq_aux # 是否在序列级别上计算辅助损失
self.norm_topk_prob = norm_topk_prob # 是否标准化 top-k 概率, 推荐启用
# 📘📘📘📘📘📘📘📘📘📘📘📘📘📘📘📘📘📘📘📘📘📘📘📘📘📘📘📘📘📘📘📘📘📘📘📘📘📘📘📘📘📘📘📘📘📘📘📘
# MiniMind Model
# 📘📘📘📘📘📘📘📘📘📘📘📘📘📘📘📘📘📘📘📘📘📘📘📘📘📘📘📘📘📘📘📘📘📘📘📘📘📘📘📘📘📘📘📘📘📘📘📘
#########################################
# MiniMind Model
#########################################
import math
import torch
import torch.nn.init as init
import torch.nn.functional as F
from torch import nn
# https://github.com/huggingface/transformers/blob/main/src/transformers/activations.py
from transformers.activations import ACT2FN
from typing import Optional, Tuple, List, Union
from transformers import PreTrainedModel, GenerationMixin, PretrainedConfig
@ -94,20 +97,29 @@ from transformers.modeling_outputs import CausalLMOutputWithPast
class RMSNorm(torch.nn.Module):
"""
RMSNorm (Root Mean Square Normalization) 标准化层
"""
def __init__(self, dim: int, eps: float = 1e-5):
super().__init__()
# 数值稳定性的小常数
self.eps = eps
# 可学习的权重参数
self.weight = nn.Parameter(torch.ones(dim))
def _norm(self, x):
# 计算 RMS (Root Mean Square) 标准化
return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps)
def forward(self, x):
# 应用 RMS 标准化并乘以权重
return self.weight * self._norm(x.float()).type_as(x)
def precompute_freqs_cis(dim: int, end: int = int(32 * 1024), rope_base: float = 1e6,
rope_scaling: Optional[dict] = None):
def precompute_freqs_cis(dim: int, end: int = int(32 * 1024), rope_base: float = 1e6, rope_scaling: Optional[dict] = None):
"""
预计算 Rotary Position Embedding (RoPE) 的频率
"""
freqs, attn_factor = 1.0 / (rope_base ** (torch.arange(0, dim, 2)[: (dim // 2)].float() / dim)), 1.0
if rope_scaling is not None:
orig_max, factor, beta_fast, beta_slow, attn_factor = (
@ -116,6 +128,7 @@ def precompute_freqs_cis(dim: int, end: int = int(32 * 1024), rope_base: float =
)
if end / orig_max > 1.0:
# YaRN: f'(i) = f(i)((1-γ) + γ/s), where γ∈[0,1] is linear ramp
# YaRN 缩放公式, 用于扩展位置编码的有效长度
inv_dim = lambda b: (dim * math.log(orig_max / (b * 2 * math.pi))) / (2 * math.log(rope_base))
low, high = max(math.floor(inv_dim(beta_fast)), 0), min(math.ceil(inv_dim(beta_slow)), dim // 2 - 1)
ramp = torch.clamp((torch.arange(dim // 2, device=freqs.device).float() - low) / max(high - low, 0.001), 0, 1)
@ -129,7 +142,11 @@ def precompute_freqs_cis(dim: int, end: int = int(32 * 1024), rope_base: float =
def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1):
"""
应用 Rotary Position Embedding (RoPE) query key
"""
def rotate_half(x):
# 将张量的后半部分移到前半部分, 实现旋转
return torch.cat((-x[..., x.shape[-1] // 2:], x[..., : x.shape[-1] // 2]), dim=-1)
q_embed = (q * cos.unsqueeze(unsqueeze_dim)) + (rotate_half(q) * sin.unsqueeze(unsqueeze_dim))
@ -138,24 +155,57 @@ def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1):
def repeat_kv(x: torch.Tensor, n_rep: int) -> torch.Tensor:
"""torch.repeat_interleave(x, dim=2, repeats=n_rep)"""
"""
重复扩充 key/value 头的 dim, 使其数量与 query 的头数对齐, 用来支持 GQA (Grouped Query Attention)
torch.repeat_interleave(x, dim=2, repeats=n_rep)
https://docs.pytorch.ac.cn/docs/stable/generated/torch.repeat_interleave.html
:param: x: 一批 tensor 数据, shape (bs, seq_len, num_key_value_heads, head_dim)
:param: n_rep: 重复次数; 如果 query 32 个头, key/value 只有 4 个头, n_rep = 8
"""
# 获取 x 的形状, 并分别赋予 bs, seq_len, num_key_value_heads, head_dim
bs, slen, num_key_value_heads, head_dim = x.shape
# n_rep = 1, 则不重复, 直接返回原数据
if n_rep == 1:
return x
return (
x[:, :, :, None, :].expand(bs, slen, num_key_value_heads, n_rep, head_dim).reshape(bs, slen, num_key_value_heads * n_rep, head_dim)
)
# n_rep != 1, 则重复 kv 指定次数
else:
return x[:, :, :, None, :].expand(bs, slen, num_key_value_heads, n_rep, head_dim).reshape(bs, slen, num_key_value_heads * n_rep, head_dim)
# #### 等效代码 ####
# # 在第三维上增加一个维度, 即 (bs, slen, num_key_value_heads, head_dim) -> (bs, slen, num_key_value_heads, 1, head_dim)
# x = x[:, :, :, None, :]
# # 将这个新增的维度扩展成 n_rep 大小 -> (bs, slen, num_key_value_heads, n_rep, head_dim)
# x = x.expand(bs, slen, num_key_value_heads, n_rep, head_dim)
# # 将维度重新调整为 -> (bs, slen, num_key_value_heads * n_rep, head_dim)
# # 注意这里必须是 num_key_value_heads * n_rep, 不能反过来
# # num_key_value_heads * n_rep 意味着 -> [kv0, kv0, kv1, kv1], 如果反过来就变为 -> [kv0, kv1, kv0, kv1]
# # 这会导致 query0 找到了 kv0, 但同组的 query1 却跑去匹配了 kv1, 这违背了 GQA 的基本原理
# x = x.reshape(bs, slen, num_key_value_heads * n_rep, head_dim)
# # 返回 x
# return x
# #################
class Attention(nn.Module):
"""
Attention 模块, 支持 GQA (Grouped Query Attention)
"""
def __init__(self, args: MiniMindConfig):
super().__init__()
# 确定 kv 头的数量, 支持 Grouped Query Attention (GQA)
# 如果 args.num_key_value_heads = None, 则 self.num_key_value_heads = args.num_attention_heads
# 否则 self.num_key_value_heads = args.num_key_value_heads
self.num_key_value_heads = args.num_attention_heads if args.num_key_value_heads is None else args.num_key_value_heads
# 断言 args.num_attention_heads 必须能被 self.num_key_value_heads 整除
assert args.num_attention_heads % self.num_key_value_heads == 0
# Query 头数
self.n_local_heads = args.num_attention_heads
# Key / Value 头数
self.n_local_kv_heads = self.num_key_value_heads
# 重复因子
self.n_rep = self.n_local_heads // self.n_local_kv_heads
# 每个头的维度
self.head_dim = args.hidden_size // args.num_attention_heads
# 投影层: 将隐藏状态映射到 query、key、value
self.q_proj = nn.Linear(args.hidden_size, args.num_attention_heads * self.head_dim, bias=False)
self.k_proj = nn.Linear(args.hidden_size, self.num_key_value_heads * self.head_dim, bias=False)
self.v_proj = nn.Linear(args.hidden_size, self.num_key_value_heads * self.head_dim, bias=False)
@ -163,137 +213,294 @@ class Attention(nn.Module):
self.attn_dropout = nn.Dropout(args.dropout)
self.resid_dropout = nn.Dropout(args.dropout)
self.dropout = args.dropout
# 检查是否支持Flash Attention
self.flash = hasattr(torch.nn.functional, 'scaled_dot_product_attention') and args.flash_attn
# print("WARNING: using slow attention. Flash Attention requires PyTorch >= 2.0")
def forward(self,
x: torch.Tensor,
position_embeddings: Tuple[torch.Tensor, torch.Tensor], # 修改为接收cos和sin
past_key_value: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
use_cache=False,
attention_mask: Optional[torch.Tensor] = None):
bsz, seq_len, _ = x.shape
xq, xk, xv = self.q_proj(x), self.k_proj(x), self.v_proj(x)
xq = xq.view(bsz, seq_len, self.n_local_heads, self.head_dim)
xk = xk.view(bsz, seq_len, self.n_local_kv_heads, self.head_dim)
xv = xv.view(bsz, seq_len, self.n_local_kv_heads, self.head_dim)
def forward(
self,
x: torch.Tensor,
position_embeddings: Tuple[torch.Tensor, torch.Tensor], # 接收 cos 和 sin -> cos, sin = position_embeddings
past_key_value: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, # 用于缓存 kv 的变量
use_cache=False, # 是否开启 kv 缓存功能
attention_mask: Optional[torch.Tensor] = None # 注意力掩码矩阵, 形状为 (batch_size, seq_len)
):
# 获取 x 的维度信息
bsz, seq_len, _ = x.shape
# 投影到 query、key、value 空间
xq, xk, xv = self.q_proj(x), self.k_proj(x), self.v_proj(x)
# 重塑为多头格式
xq = xq.view(bsz, seq_len, self.n_local_heads, self.head_dim)
xk = xk.view(bsz, seq_len, self.n_local_kv_heads, self.head_dim)
xv = xv.view(bsz, seq_len, self.n_local_kv_heads, self.head_dim)
# 获取位置编码
cos, sin = position_embeddings
# 应用旋转位置编码 (qk需要, k不需要)
xq, xk = apply_rotary_pos_emb(xq, xk, cos, sin)
cos, sin = position_embeddings
xq, xk = apply_rotary_pos_emb(xq, xk, cos, sin)
# kv_cache 实现: 拼接过去的 key/value
if past_key_value is not None:
xk = torch.cat([past_key_value[0], xk], dim=1)
xv = torch.cat([past_key_value[1], xv], dim=1)
past_kv = (xk, xv) if use_cache else None
# kv_cache实现
if past_key_value is not None:
xk = torch.cat([past_key_value[0], xk], dim=1)
xv = torch.cat([past_key_value[1], xv], dim=1)
past_kv = (xk, xv) if use_cache else None
# 转置为注意力计算格式: (batch, heads, seq_len, head_dim)
xq, xk, xv = (
xq.transpose(1, 2),
repeat_kv(xk, self.n_rep).transpose(1, 2),
repeat_kv(xv, self.n_rep).transpose(1, 2)
)
xq, xk, xv = (
xq.transpose(1, 2),
repeat_kv(xk, self.n_rep).transpose(1, 2),
repeat_kv(xv, self.n_rep).transpose(1, 2)
)
# 使用 Flash Attention
if self.flash and (seq_len > 1) and (past_key_value is None) and (attention_mask is None or torch.all(attention_mask == 1)):
# https://docs.pytorch.ac.cn/docs/stable/generated/torch.nn.functional.scaled_dot_product_attention.html
output = F.scaled_dot_product_attention(xq, xk, xv, dropout_p=self.dropout if self.training else 0.0, is_causal=True)
# 其余情况使用 "标注注意力"
else:
# 计算注意力分数
scores = (xq @ xk.transpose(-2, -1)) / math.sqrt(self.head_dim)
# 因果掩码: 防止看到未来信息
# https://docs.pytorch.ac.cn/docs/stable/generated/torch.triu.html
# torch.triu(input, diagonal)
# 返回矩阵 (2D 张量) 或矩阵批次的上三角部分 input, 结果张量 out 的其他元素将设置为 0
# diagonal 控制要考虑的对角线
# 如果 diagonal = 0, 则保留主对角线及其上方的所有元素
# 正值会排除主对角线以上的相同数量的对角线
# 负值会包含主对角线以下的相同数量的对角线
scores[:, :, :, -seq_len:] += torch.triu(torch.full((seq_len, seq_len), float("-inf"), device=scores.device), diagonal=1)
if self.flash and (seq_len > 1) and (past_key_value is None) and (attention_mask is None or torch.all(attention_mask == 1)):
output = F.scaled_dot_product_attention(xq, xk, xv, dropout_p=self.dropout if self.training else 0.0, is_causal=True)
else:
scores = (xq @ xk.transpose(-2, -1)) / math.sqrt(self.head_dim)
scores[:, :, :, -seq_len:] += torch.triu(torch.full((seq_len, seq_len), float("-inf"), device=scores.device), diagonal=1)
# attention_mask shape: (batch_size, seq_len)
# - 1 表示该位置是有效 token
# - 0 表示该位置是 padding
if attention_mask is not None:
# (bs, seq_len) -> (bs, 1, 1, seq_len)
# 为了让掩码能够广播 (broadcast) 到注意力分数矩阵的维度上
# 因为 scores shape: (bs, num_heads, seq_len, seq_len)
extended_attention_mask = attention_mask.unsqueeze(1).unsqueeze(2)
# -1e9 是一个极大的负数(-1000000000.0), 用于屏蔽无效位置
extended_attention_mask = (1.0 - extended_attention_mask) * -1e9
# 应用掩码:
# - 如果 extended_attention_mask = 1, 应用后对应位置 mask = 0 -> scores 不变 (保留)
# - 如果 extended_attention_mask = 0, 应用后对应位置 mask 接近 -inf -> scores 分数会大幅拉低, 经过 Softmax ≈ 0
scores = scores + extended_attention_mask
if attention_mask is not None:
extended_attention_mask = attention_mask.unsqueeze(1).unsqueeze(2)
extended_attention_mask = (1.0 - extended_attention_mask) * -1e9
scores = scores + extended_attention_mask
# Softmax 归一化
scores = F.softmax(scores.float(), dim=-1).type_as(xq)
# Dropout
scores = self.attn_dropout(scores)
# 计算注意力输出 (乘上 Value)
output = scores @ xv
scores = F.softmax(scores.float(), dim=-1).type_as(xq)
scores = self.attn_dropout(scores)
output = scores @ xv
output = output.transpose(1, 2).reshape(bsz, seq_len, -1)
output = self.resid_dropout(self.o_proj(output))
return output, past_kv
# 重塑并输出
output = output.transpose(1, 2).reshape(bsz, seq_len, -1)
output = self.resid_dropout(self.o_proj(output))
return output, past_kv
class FeedForward(nn.Module):
"""
前反馈神经网络
"""
def __init__(self, config: MiniMindConfig):
super().__init__()
# 计算 FFN 中间层大小, 确保是 64 的倍数, 是为了对 GPU 内存做对齐优化, 用来加速推理
if config.intermediate_size is None:
# SwiGLU FFN 中间层大小建议为 hidden_size * 8/3 (LLaMA 风格)
intermediate_size = int(config.hidden_size * 8 / 3)
# 对中间层大小进行 64 的倍数对齐: 向上取整到最近的 64 的倍数
# - 公式: (x + n - 1) // n * n
# - 加上 63 -> 对该值 "向下取整(//)" -> 乘上 64
config.intermediate_size = 64 * ((intermediate_size + 64 - 1) // 64)
self.gate_proj = nn.Linear(config.hidden_size, config.intermediate_size, bias=False)
self.down_proj = nn.Linear(config.intermediate_size, config.hidden_size, bias=False)
self.up_proj = nn.Linear(config.hidden_size, config.intermediate_size, bias=False)
# SwiGLU FFN 的三个线性层
self.gate_proj = nn.Linear(config.hidden_size, config.intermediate_size, bias=False) # 门控层
self.down_proj = nn.Linear(config.intermediate_size, config.hidden_size, bias=False) # 下投影层
self.up_proj = nn.Linear(config.hidden_size, config.intermediate_size, bias=False) # 上投影层
self.dropout = nn.Dropout(config.dropout)
# 激活函数, 默认为 SiLU
# https://github.com/huggingface/transformers/blob/main/src/transformers/activations.py
# from transformers.activations import ACT2FN
self.act_fn = ACT2FN[config.hidden_act]
def forward(self, x):
# SwiGLU: gate * up -> down
return self.dropout(self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x)))
class MoEGate(nn.Module):
"""
MoE 门控网络
- 根据输入的隐藏状态, 动态地为每个 token 选择 top-k 个最合适的专家, 并输出其权重
- 同时可选地计算一个负载均衡辅助损失来防止专家 "垄断"
输入:
- hidden_states: (batch_size, seq_len, hidden_size)
输出:
- topk_idx: 每个 token 选择的 top-k 专家索引, 形状为 (bsz * seq_len, k)
- topk_weight: 对应的专家权重 (归一化后), 形状同上
- aux_loss: 辅助损失项 (用于训练时均衡专家负载)
"""
def __init__(self, config: MiniMindConfig):
super().__init__()
self.config = config
self.top_k = config.num_experts_per_tok
self.n_routed_experts = config.n_routed_experts
self.top_k = config.num_experts_per_tok # 每个 token 选择的专家数量
self.n_routed_experts = config.n_routed_experts # 路由专家总数
self.scoring_func = config.scoring_func
self.alpha = config.aux_loss_alpha
self.seq_aux = config.seq_aux
self.scoring_func = config.scoring_func # 评分函数, 仅支持 softmax
self.alpha = config.aux_loss_alpha # 辅助损失权重, 常见为 0.01 ~ 0.1
self.seq_aux = config.seq_aux # 是否使用序列级辅助损失 (False 比较常见)
self.norm_topk_prob = config.norm_topk_prob
self.gating_dim = config.hidden_size
self.weight = nn.Parameter(torch.empty((self.n_routed_experts, self.gating_dim)))
self.reset_parameters()
self.norm_topk_prob = config.norm_topk_prob # 是否标准化 top-k 概率, 推荐启用
self.gating_dim = config.hidden_size # 门控维度, 注意维度和隐藏层维度相同
self.weight = nn.Parameter(torch.empty((self.n_routed_experts, self.gating_dim))) # 门控权重
self.reset_parameters() # Kaiming 初始化 self.weight 权重
def reset_parameters(self) -> None:
# Kaiming 初始化权重
init.kaiming_uniform_(self.weight, a=math.sqrt(5))
def forward(self, hidden_states):
"""
:param: hidden_states (bsz, seq_len, hidden_size)
"""
# 1.记录原始形状并展平, 将所有 token 视为独立样本
bsz, seq_len, h = hidden_states.shape
hidden_states = hidden_states.view(-1, h)
# shape: (bsz * seq_len, hidden_size)
hidden_states = hidden_states.view(-1, h)
# 2.计算路由分数 (Gating Scores)
# logits: (bsz * seq_len, n_routed_experts)
# 这一步计算 token 与每个专家的相似度
# https://docs.pytorch.ac.cn/docs/stable/generated/torch.nn.functional.linear.html
# - hidden_states: (bsz * seq_len, hidden_size)
# - self.weight: (n_routed_experts, hidden_size) -> self.gating_dim = config.hidden_size
# - logits: hidden_states @ weight.T = (bsz * seq_len, n_routed_experts)
logits = F.linear(hidden_states, self.weight, None)
if self.scoring_func == 'softmax':
# 使用 softmax 转换为概率分布, 和为1
# scores: (bsz * seq_len, n_routed_experts)
scores = logits.softmax(dim=-1)
else:
raise NotImplementedError(f'insupportable scoring function for MoE gating: {self.scoring_func}')
# 3.选择 top-k 专家
# https://docs.pytorch.ac.cn/docs/stable/generated/torch.topk.html
# 给定 scores 张量在给定维度(dim=-1)上最大的 k 个元素 -> (values, indices)
# - topk_weight: 对应专家概率 (bsz * seq_len, top_k)
# - topk_idx: 对应专家索引 (bsz * seq_len, top_k)
topk_weight, topk_idx = torch.topk(scores, k=self.top_k, dim=-1, sorted=False)
# 4.标准化 top-k 概率
# 如果选择了多个专家, 则将它们的概率重新归一化, 使其总和为 1
if self.top_k > 1 and self.norm_topk_prob:
# 分母
denominator = topk_weight.sum(dim=-1, keepdim=True) + 1e-20
# 权重再归一化后的概率
topk_weight = topk_weight / denominator
if self.training and self.alpha > 0.0:
# 5.计算辅助损失 (用于负载均衡)
if self.training and self.alpha > 0.0: # self.training 继承于 torch.nn.Module
# (bsz * seq_len, n_routed_experts)
scores_for_aux = scores
aux_topk = self.top_k
# (bsz * seq_len, top_k) -> (bsz, seq_len * top_k)
topk_idx_for_aux_loss = topk_idx.view(bsz, -1)
# 分支1: 序列级辅助损失 (较少见)
if self.seq_aux:
# (bsz, seq_len, n_routed_experts)
scores_for_seq_aux = scores_for_aux.view(bsz, seq_len, -1)
# ce 是 Counter for Experts / Cumulative Expert usage count
# ce shape (bsz, n_routed_experts)
ce = torch.zeros(bsz, self.n_routed_experts, device=hidden_states.device)
ce.scatter_add_(1, topk_idx_for_aux_loss,
torch.ones(bsz, seq_len * aux_topk, device=hidden_states.device)).div_(
seq_len * aux_topk / self.n_routed_experts)
# 目标张量为 ce, 作用是统计每个 sequence 中各个专家被选中的频率
# https://docs.pytorch.ac.cn/docs/stable/generated/torch.Tensor.scatter_add_.html
# - Tensor.scatter_add_(dim, index, src)
# - 把 src[i] 的值, 累加到 "目标张量" 中由 index[i] 指定的 dim 维度上的位置
# - self, index 和 src 必须具有相同的维度数
# div_(seq_len * aux_topk / self.n_routed_experts) 是归一化因子
# 如果完美均衡, 每个专家应被选中 (seq_len * top_k) / n_routed_experts 次
# 因此 ce.div_(xxx) 之后, ce[i][j] 变成 "归一化的期望频次"
# https://docs.pytorch.ac.cn/docs/stable/generated/torch.div.html
# - torch.div(input, other, *, rounding_mode=None, out=None)
# - 将输入 input 的每个元素除以 other 的相应元素
ce.scatter_add_(
# 在 "专家维度" 上累加
dim = 1,
# 每个 token 的 top-k 专家编号
index = topk_idx_for_aux_loss,
# 对每个选中操作, 这里为每次加 1
src = torch.ones(bsz, seq_len * aux_topk, device=hidden_states.device)).div_(seq_len * aux_topk / self.n_routed_experts)
# 计算辅助损失
# 目标是最小化专家负载的不均衡性, 即让 f_i ≈ 1/n
# 公式: aux_loss = alpha * (P_i * f_i).sum()
# - P_i: 第 i 个专家的平均预测概率, 来自 scores_for_seq_aux.mean(dim=1)
# - f_i: 第 i 个专家的实际被选中频率 (归一化后), 来自 ce
aux_loss = (ce * scores_for_seq_aux.mean(dim=1)).sum(dim=1).mean() * self.alpha
######################################################
# scores_for_seq_aux.mean(dim=1)
# -> 每个样本中, 各个专家的平均预测概率 (P_i)
# ce * scores_for_seq_aux.mean(dim=1)
# -> 每个样本中, 对每个专家的 频率 × 预测概率 (P_i * f_i)
# .sum(dim=1).mean()
# -> 每个样本的负载不均衡程度总和, 并平均到整个 batch
# self.alpha
# -> 缩放, 避免主导主 loss
######################################################
# 分支2: Token 级/Batch 级辅助损失 (默认)
else:
# 1.计算 expert 被选中的实际频率 f_i
# - mask_ce shape: (bsz * seq_len * n_routed_experts, n_routed_experts)
# https://docs.pytorch.ac.cn/docs/stable/generated/torch.nn.functional.one_hot.html
# 接收形状为 (*) 的索引值的 LongTensor, 并返回形状为 (*, num_classes) 的张量
# - 每行代表一个选择事件
# - 每列代表一个专家 (num_classes=self.n_routed_experts)
mask_ce = F.one_hot(topk_idx_for_aux_loss.view(-1), num_classes=self.n_routed_experts)
# 每个专家在 batch 中被选中的比例
# ce shape: (n_routed_experts,)
ce = mask_ce.float().mean(0)
Pi = scores_for_aux.mean(0)
# fi: 归一化后的频率, 理想情况下 fi 应接近 1
fi = ce * self.n_routed_experts
# 2.计算 expert 的平均路由概率 P_i
# Pi shape: (n_routed_experts,)
Pi = scores_for_aux.mean(0)
# 3.计算辅助损失
# 计算点积并求和: aux_loss = alpha * (P_i * f_i).sum()
# 期望分布是均匀的, 这个 Loss 最小化时通常意味着均衡
aux_loss = (Pi * fi).sum() * self.alpha
# 不需要计算辅助损失
else:
aux_loss = scores.new_zeros(1).squeeze()
# 返回结果
# - topk_idx: 选中的专家索引: (bsz * seq_len, top_k)
# - topk_weight: 选中的专家权重: (bsz * seq_len, top_k)
# - aux_loss: 辅助损失, 用于负载均衡
return topk_idx, topk_weight, aux_loss
class MOEFeedForward(nn.Module):
"""
MOE 前反馈神经网络
"""
def __init__(self, config: MiniMindConfig):
super().__init__()
self.config = config
# 创建路由专家列表, 只有被 Gate 选中的才会被计算
self.experts = nn.ModuleList([
FeedForward(config)
for _ in range(config.n_routed_experts)
])
# 门控网络: 输入是 x, 输出是 top-k 的专家索引和对应权重 (还有辅助损失 aux_loss)
self.gate = MoEGate(config)
# 创建共享专家列表 (可选)
if config.n_shared_experts > 0:
self.shared_experts = nn.ModuleList([
FeedForward(config)
@ -301,147 +508,351 @@ class MOEFeedForward(nn.Module):
])
def forward(self, x):
"""
:param: x: (bsz, seq_len, hidden_size)
"""
# 原始输入, 用于后续共享专家相加
identity = x
orig_shape = x.shape
bsz, seq_len, _ = x.shape
# 使用门控机制选择专家
# 通过 Gate 获取路由结果
# - topk_idx: 选中的专家索引: (bsz * seq_len, top_k)
# - topk_weight: 选中的专家概率: (bsz * seq_len, top_k)
# - aux_loss: 辅助损失, 用于负载均衡
topk_idx, topk_weight, aux_loss = self.gate(x)
x = x.view(-1, x.shape[-1])
flat_topk_idx = topk_idx.view(-1)
# 准备数据
x = x.view(-1, x.shape[-1]) # (bsz * seq_len, hidden_size)
flat_topk_idx = topk_idx.view(-1) # (bsz * seq_len * top_k)
# 训练模式: 使用专家并行处理
if self.training:
x = x.repeat_interleave(self.config.num_experts_per_tok, dim=0)
y = torch.empty_like(x, dtype=x.dtype)
# 将 x 重复扩展以匹配专家
# 如果 k=2, 每个 token 需要被处理两次
# https://docs.pytorch.ac.cn/docs/stable/generated/torch.repeat_interleave.html
# https://docs.pytorch.ac.cn/docs/stable/generated/torch.Tensor.repeat.html
# - Tensor.repeat(*repeats)
# - repeats: 每个元素重复的次数
# - dim: 沿着指定维度进行重复
x = x.repeat_interleave(self.config.num_experts_per_tok, dim=0) # (bsz * seq_len * top_k, hidden_size)
# https://docs.pytorch.ac.cn/docs/stable/generated/torch.empty_like.html
# 预分配输出空间 (返回一个大小与 x 相同的未初始化张量)
y = torch.empty_like(x, dtype=x.dtype) # (bsz * seq_len * top_k, hidden_size)
# 遍历每一个专家
for i, expert in enumerate(self.experts):
expert_out = expert(x[flat_topk_idx == i])
if expert_out.shape[0] > 0: y[flat_topk_idx == i] = expert_out.to(y.dtype)
else: y[flat_topk_idx == i] = expert_out.to(y.dtype) + 0 * sum(p.sum() for p in expert.parameters())
# 制作掩码: 找出所有分配给专家 i 的 token 位置
mask = (flat_topk_idx == i)
# x[mask]: 只属于该专家的输入
expert_out = expert(x[mask])
# 专家输出不为空时, 将专家输出写入到 y 中对应位置
if expert_out.shape[0] > 0:
y[mask] = expert_out.to(y.dtype)
# 否则, 添加一个零梯度的项, 保证即使没有 token 被分配, 参数仍有梯度
else:
# 这里的 0 * sum(...) 确保该专家的参数在计算图中, 避免 DDP 报错 (unused parameters)
y[mask] = expert_out.to(y.dtype) + 0 * sum(p.sum() for p in expert.parameters())
# 聚合结果
# y 目前的形状是 (bsz * seq_len * top_k, hidden_size), 代表 "展开的" 专家输出
# 需要将 y 加权合并为 -> (bsz * seq_len, hidden_size)
# - topk_weight (bsz * seq_len, top_k)
# - y.view(*topk_weight.shape, -1): (bsz * seq_len, top_k, hidden_size)
# - topk_weight.unsqueeze(-1): (bsz * seq_len, top_k, 1)
# - 广播乘法 (N, k, h) * (N, k, 1) = (N, k, h)
# - 最后在 top_k 维度上求和 (dim=1): (bsz * seq_len, top_k, hidden_size) -> (bsz * seq_len, hidden_size)
y = (y.view(*topk_weight.shape, -1) * topk_weight.unsqueeze(-1)).sum(dim=1)
# y 已经是 (bsz, seq_len, hidden_size) 了, 这里的 view 只是为了保险
y = y.view(*orig_shape)
# 推理模式: 使用优化的专家推理
# - 使用优化后的串行处理 (减少显存占用, 避免大量 mask 操作)
else:
# flat_topk_idx: (bsz * seq_len * top_k)
# topk_weight.view(-1, 1) 展平了所有维度: (bsz * seq_len * top_k, 1)
y = self.moe_infer(x, flat_topk_idx, topk_weight.view(-1, 1)).view(*orig_shape)
# 添加共享专家的输出
if self.config.n_shared_experts > 0:
for expert in self.shared_experts:
y = y + expert(identity)
self.aux_loss = aux_loss
# (bsz, seq_len, hidden_size)
return y
@torch.no_grad()
def moe_infer(self, x, flat_expert_indices, flat_expert_weights):
"""
推理模式下高效 MoE 实现
不使用 repeat_interleave 复制 x (省显存), 而是利用索引重排
:param x: (bsz * seq_len, hidden_size)
:param flat_expert_indices: (bsz * seq_len * top_k) 专家编号
:param flat_expert_weights: (bsz * seq_len * top_k, 1) 对应的权重
"""
# 结果缓存 (bsz * seq_len, hidden_size)
expert_cache = torch.zeros_like(x)
# 1.对所有 token 的任务分配进行排序
# idxs 是排序后的索引, 能够把 "分配给专家0的任务", "分配给专家1的任务" 聚在一起
idxs = flat_expert_indices.argsort()
# 2.计算每个专家处理多少个 token
# - bincount 统计每个专家出现的次数
# https://docs.pytorch.ac.cn/docs/stable/generated/torch.bincount.html
# - cumsum 返回在 dim 维度上的累积和, 用于切片
# https://docs.pytorch.ac.cn/docs/stable/generated/torch.cumsum.html
tokens_per_expert = flat_expert_indices.bincount().cpu().numpy().cumsum(0)
# 3.计算排序后的索引对应原始的 token_id
# - flat_expert_indices 长度是 N*k, 是展开的 (bsz * seq_len * top_k)
# - // top_k 操作将 "展开后的索引" 映射回 "原始 token 的行号"
token_idxs = idxs // self.config.num_experts_per_tok
# 当tokens_per_expert = [6, 15, 20, 26]tokens_per_expert.shape[0]即为专家数量此时为4
# 且token_idxs = [3, 7, 19, 21, 24, 25, 4, 5, 6, 10, 11, 12...] 时
# 意味token_idxs[:6] -> [3, 7, 19, 21, 24, 25]这6个位置属于专家0处理的token每个token有可能被多个专家处理这取决于num_experts_per_tok
# 接下来9个位置token_idxs[6:15] -> [4, 5, 6, 10, 11, 12...]属于专家1处理的token...依此类推
#####################################
# 当 tokens_per_expert = [6, 15, 20, 26], tokens_per_expert.shape[0] 即为专家数量此时为4
# 且 token_idxs = [3, 7, 19, 21, 24, 25, 4, 5, 6, 10, 11, 12...] 时
# 意味 token_idxs[:6] -> [3, 7, 19, 21, 24, 25] 这6个位置属于 "专家0" 处理的 token (每个token有可能被多个专家处理, 这取决于 num_experts_per_tok)
# 接下来9个位置 token_idxs[6:15] -> [4, 5, 6, 10, 11, 12...] 属于 "专家1" 处理的 token...依此类推
#####################################
# 遍历每个专家 (根据 tokens_per_expert 切片)
for i, end_idx in enumerate(tokens_per_expert):
start_idx = 0 if i == 0 else tokens_per_expert[i - 1]
# 该专家在这个 batch 没有任务
if start_idx == end_idx:
continue
expert = self.experts[i]
exp_token_idx = token_idxs[start_idx:end_idx]
expert_tokens = x[exp_token_idx]
expert_out = expert(expert_tokens).to(expert_cache.dtype)
expert_out.mul_(flat_expert_weights[idxs[start_idx:end_idx]])
expert_cache.scatter_add_(0, exp_token_idx.view(-1, 1).repeat(1, x.shape[-1]), expert_out)
# 获取属于专家 i 的所有 token 的原始行号
exp_token_idx = token_idxs[start_idx:end_idx]
# 从 x 中取出这些 token
expert_tokens = x[exp_token_idx]
# 前向计算
expert_out = expert(expert_tokens).to(expert_cache.dtype)
# 乘上对应的权重
# - idxs[start_idx:end_idx] 取出的是 "flat_expert_weights" 对应的索引
expert_out.mul_(flat_expert_weights[idxs[start_idx:end_idx]])
# 将结果累加回 expert_cache
# scatter_add_ 处理这种情况: 如果一个 token 同时选择了专家 A 和 专家 B
# 它的结果需要是 (OutA * wA) + (OutB * wB)
# 这里利用 scatter_add 根据 token 索引将结果加回去
expert_cache.scatter_add_(0, exp_token_idx.view(-1, 1).repeat(1, x.shape[-1]), expert_out)
#####################################
# https://docs.pytorch.ac.cn/docs/stable/generated/torch.Tensor.scatter_add_.html
# - Tensor.scatter_add_(dim, index, src)
# - 把 src[i] 的值, 累加到 "目标张量" 中由 index[i] 指定的 dim 维度上的位置
# - self, index 和 src 必须具有相同的维度数
#####################################
# (bsz * seq_len, hidden_size)
return expert_cache
class MiniMindBlock(nn.Module):
"""
MiniMind 模型的一个 Transformer
"""
def __init__(self, layer_id: int, config: MiniMindConfig):
super().__init__()
self.num_attention_heads = config.num_attention_heads
self.hidden_size = config.hidden_size
self.head_dim = config.hidden_size // config.num_attention_heads
# 自注意力层 (Attention 模块)
self.self_attn = Attention(config)
self.layer_id = layer_id
self.input_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
self.post_attention_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
# 输入前标准化
self.input_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
# 注意力后标准化
self.post_attention_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
# 根据配置选择FFN类型: 普通FFN / MoE FFN
self.mlp = FeedForward(config) if not config.use_moe else MOEFeedForward(config)
def forward(self, hidden_states, position_embeddings, past_key_value=None, use_cache=False, attention_mask=None):
"""
:param hidden_states: (batch_size, seq_len, hidden_size)
:param position_embeddings: Tuple[torch.Tensor, torch.Tensor]
:param past_key_value: Tuple[torch.Tensor, torch.Tensor]
:param use_cache: 是否开启 kv 缓存功能
:param attention_mask: 注意力掩码矩阵, 形状为 (batch_size, seq_len)
"""
# 第一个残差连接:输入层标准化 -> 注意力 -> 残差连接
residual = hidden_states
hidden_states, present_key_value = self.self_attn(
self.input_layernorm(hidden_states), position_embeddings,
past_key_value, use_cache, attention_mask
self.input_layernorm(hidden_states),
position_embeddings,
past_key_value,
use_cache,
attention_mask
)
hidden_states += residual
# 第二个残差连接:注意力后标准化 -> FFN -> 残差连接
hidden_states = hidden_states + self.mlp(self.post_attention_layernorm(hidden_states))
# 返回 输出和缓存的键值对
return hidden_states, present_key_value
class MiniMindModel(nn.Module):
"""
MiniMind 模型主类
"""
def __init__(self, config: MiniMindConfig):
super().__init__()
self.config = config
self.vocab_size, self.num_hidden_layers = config.vocab_size, config.num_hidden_layers
self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size)
self.vocab_size = config.vocab_size # 词表大小
self.num_hidden_layers = config.num_hidden_layers # 隐藏层数量
self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size) # 词嵌入层
self.dropout = nn.Dropout(config.dropout)
# 构建多个相同的 Transformer 块 (MiniMindBlock)
# 每个块包含自注意力 + FFN, 通过 ModuleList 管理参数
self.layers = nn.ModuleList([MiniMindBlock(l, config) for l in range(self.num_hidden_layers)])
self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps) # 最终标准化层
freqs_cos, freqs_sin = precompute_freqs_cis(dim=config.hidden_size // config.num_attention_heads,
end=config.max_position_embeddings, rope_base=config.rope_theta,
rope_scaling=config.rope_scaling)
# 预计算旋转位置编码
freqs_cos, freqs_sin = precompute_freqs_cis(
dim=config.hidden_size // config.num_attention_heads,
end=config.max_position_embeddings, rope_base=config.rope_theta,
rope_scaling=config.rope_scaling
)
# 注册为缓冲区, 不会被优化
self.register_buffer("freqs_cos", freqs_cos, persistent=False)
self.register_buffer("freqs_sin", freqs_sin, persistent=False)
def forward(self,
input_ids: Optional[torch.Tensor] = None,
attention_mask: Optional[torch.Tensor] = None,
past_key_values: Optional[List[Tuple[torch.Tensor, torch.Tensor]]] = None,
use_cache: bool = False,
**kwargs):
def forward(
self,
# 输入 token 序列, 形状:(batch_size, seq_len)
input_ids: Optional[torch.Tensor] = None,
# 注意力掩码1 表示有效 token, 0 表示 padding
attention_mask: Optional[torch.Tensor] = None,
# KV 缓存列表 (用于生成任务)
# - 每个元素为 (key, value) 对
# - 形状为 [(bs, seq_len_k, num_heads, head_dim), ...]
past_key_values: Optional[List[Tuple[torch.Tensor, torch.Tensor]]] = None,
# 是否启用缓存 (推理加速)
use_cache: bool = False,
**kwargs
):
# 提取输入维度
batch_size, seq_length = input_ids.shape
if hasattr(past_key_values, 'layers'): past_key_values = None
# KV 缓存处理
#######################################################################################
# https://docs.python.org/zh-cn/3.14/library/functions.html#hasattr
# - 如果字符串是对象的属性之一的名称, 则返回 True, 否则返回 False
if hasattr(past_key_values, 'layers'):
# 如果传进来的 past_key_values 是一个带 .layers 属性的对象 (比如 HuggingFace 封装的输出)
# 那就把它当作无效缓存处理, 清空为 None
past_key_values = None
# 对 past_key_values 进行默认值填充
# 如果是 None、空列表 [] 或其他 "假值(falsy)"
# 就用一个长度为 num_hidden_layers 的全 None 列表替代 -> [None] * len(self.layers)
past_key_values = past_key_values or [None] * len(self.layers)
#######################################################################################
# 计算起始位置用于KV缓存
start_pos = past_key_values[0][0].shape[1] if past_key_values[0] is not None else 0
# 词嵌入 + dropout
hidden_states = self.dropout(self.embed_tokens(input_ids))
# 获取位置编码
position_embeddings = (
self.freqs_cos[start_pos:start_pos + seq_length],
self.freqs_sin[start_pos:start_pos + seq_length]
)
# 逐层前向传播
presents = []
# Transformer 块堆叠处理
for layer_idx, (layer, past_key_value) in enumerate(zip(self.layers, past_key_values)):
# 每层 MiniMindBlock 接受如下参数
# 并输出 "下一个时刻隐藏状态 hidden_states" 和 "KV 缓存对 present"
hidden_states, present = layer(
hidden_states,
position_embeddings,
past_key_value=past_key_value,
use_cache=use_cache,
attention_mask=attention_mask
hidden_states, # 当前隐藏状态
position_embeddings, # 对应的 RoPE 位置编码 (cos, sin)
past_key_value=past_key_value, # 上一时刻缓存的 KV 缓存 (key, value) tuple
use_cache=use_cache, # 是否启用缓存
attention_mask=attention_mask # 注意力掩码
)
# 存储每层的 kv 缓存
presents.append(present)
# 最终标准化
hidden_states = self.norm(hidden_states)
# 计算 MoE 辅助损失
# 对所有使用了 MOEFeedForward 的层, 累加其 .aux_loss 属性 (属性来自 MOEFeedForward, 值来自 MoEGate)
# https://docs.pytorch.ac.cn/docs/stable/generated/torch.Tensor.new_zeros.html
# - Tensor.new_zeros(size, *, dtype=None, device=None, requires_grad=False, layout=torch.strided, pin_memory=False)
# - 返回一个大小为 size 的、填充有 0 的 Tensor;
# - 默认情况下, 返回的 Tensor 具有与此 Tensor 相同的 torch.dtype 和 torch.device
aux_loss = sum([l.mlp.aux_loss for l in self.layers if isinstance(l.mlp, MOEFeedForward)], hidden_states.new_zeros(1).squeeze())
######################################################################################
# 为什么用 sum(..., start) 而不是直接 sum(...)?
# - 若无任何 MoE 层, 则列表为空 -> sum([]) 返回 int(0)
# - 但模型其他部分使用 torch.Tensor, int 和 Tensor 相加会报错
# - 因此必须显式指定一个与 hidden_states 同 device/dtype 的零标量作为初始值
######################################################################################
# hidden_states.new_zeros(1).squeeze()
# - hidden_states.new_zeros(1) 返回一个标量张量 [0.]
# - squeeze() 删除维度 1, 变成标量 tensor: tensor(0.)
#
##### 等效代码 ########################################################################
# - zero_scalar = hidden_states.new_zeros(1).squeeze() # 创建一个标量 tensor(0.)
# - aux_loss = zero_scalar # aux_loss 初始值为 tensor(0.)
# - for layer in self.layers:
# - if isinstance(layer.mlp, MOEFeedForward): # 检查是否是 MoE 类型
# - aux_loss = aux_loss + layer.mlp.aux_loss # 累加辅助损失
######################################################################################
# 返回最终隐藏状态、缓存和 MoE 辅助损失
return hidden_states, presents, aux_loss
class MiniMindForCausalLM(PreTrainedModel, GenerationMixin):
"""
MiniMindForCausalLM , 用于构建因果语言模型
- MiniMindModel 主干与线性输出头 (lm_head) 组合, 实现完整的文本生成能力
- 兼容 HuggingFace Transformers 的训练/推理接口 ( model.generate())
- 并支持自回归解码KV 缓存损失计算等标准功能
# 参考链接
https://huggingface.co/docs/transformers/main_classes/text_generation
"""
config_class = MiniMindConfig
def __init__(self, config: MiniMindConfig = None):
# 初始化配置项,如果未提供则使用默认参数初始化
self.config = config or MiniMindConfig()
super().__init__(self.config)
# MiniMind 模型主体
self.model = MiniMindModel(self.config)
# 语言模型头 (输出头)
self.lm_head = nn.Linear(self.config.hidden_size, self.config.vocab_size, bias=False)
# 权重共享:词嵌入和语言模型头共享权重
# - 减少参数量 ~15~20%(约 vocab_size * hidden_size
# - 提升泛化能力, 增强词向量与输出分布的一致性 (LLaMA、GPT 系列标准做法)
# - 此操作需在 super().__init__() 后执行, 因为父类可能初始化参数
self.model.embed_tokens.weight = self.lm_head.weight
def forward(self,
input_ids: Optional[torch.Tensor] = None,
attention_mask: Optional[torch.Tensor] = None,
labels: Optional[torch.Tensor] = None,
past_key_values: Optional[List[Tuple[torch.Tensor, torch.Tensor]]] = None,
use_cache: bool = False,
logits_to_keep: Union[int, torch.Tensor] = 0,
**args):
def forward(
self,
input_ids: Optional[torch.Tensor] = None,
attention_mask: Optional[torch.Tensor] = None,
labels: Optional[torch.Tensor] = None,
past_key_values: Optional[List[Tuple[torch.Tensor, torch.Tensor]]] = None,
use_cache: bool = False,
logits_to_keep: Union[int, torch.Tensor] = 0,
**args
):
# 主干模型的前向传播
hidden_states, past_key_values, aux_loss = self.model(
input_ids=input_ids,
attention_mask=attention_mask,
@ -449,15 +860,29 @@ class MiniMindForCausalLM(PreTrainedModel, GenerationMixin):
use_cache=use_cache,
**args
)
# 计算logits (只保留最后 logits_to_keep 个位置以节省内存)
# - 适用于长上下文推理中只关心生成末尾结果的场景
# - 若 logits_to_keep = 0, 则保留全部序列
# - 若为整数 N, 则取最后 N 个 token 的 logits
slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep
logits = self.lm_head(hidden_states[:, slice_indices, :])
loss = None
if labels is not None:
shift_logits = logits[..., :-1, :].contiguous()
shift_labels = labels[..., 1:].contiguous()
# 标签对齐:采用标准的因果语言建模损失计算方式
shift_logits = logits[..., :-1, :].contiguous() # 去掉最后一个位置的预测
shift_labels = labels[..., 1:].contiguous() # 去掉第一个位置的标签
# 交叉熵损失
loss = F.cross_entropy(shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1), ignore_index=-100)
# 构建输出
output = CausalLMOutputWithPast(loss=loss, logits=logits, past_key_values=past_key_values, hidden_states=hidden_states)
output.aux_loss = aux_loss
return output
# CausalLMOutputWithPast 实例
# https://huggingface.co/docs/transformers/v5.0.0/en/main_classes/output#transformers.modeling_outputs.CausalLMOutputWithPast
# - loss (torch.Tensor): 训练时计算的交叉熵损失; 推理时为 None
# - logits (torch.Tensor): 预测得分, 形状: (batch_size, seq_len, vocab_size) 或截断后长度
# - past_key_values (List[Tuple[torch.Tensor]]): 每层的 KV 缓存, 用于后续 token 生成
# - hidden_states (torch.Tensor): 所有 Transformer 层输出, 默认不返回, 但继承类提供
# - aux_loss (torch.Tensor, 自定义字段): MoE 辅助损失, 仅在启用且训练时非零
return output

View File

@ -1,12 +1,13 @@
import os
import sys
# 设置包名
__package__ = "trainer"
# 将项目根目录添加到系统路径,确保能够导入同级目录的模块
sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), '..')))
import argparse
import time
import warnings
import torch
import torch.nn.functional as F
import torch.distributed as dist
@ -18,52 +19,101 @@ from model.model_minimind import MiniMindConfig
from dataset.lm_dataset import SFTDataset
from trainer.trainer_utils import get_lr, Logger, is_main_process, lm_checkpoint, init_distributed_mode, setup_seed, init_model, SkipBatchSampler
# 忽略所有警告信息, 避免训练过程中输出过多警告
import warnings
warnings.filterwarnings('ignore')
def distillation_loss(student_logits, teacher_logits, temperature=1.0, reduction='batchmean'):
"""
计算知识蒸馏的 KL 散度损失
使用 Softmax KL Divergence 衡量学生模型和教师模型预测分布的差异
温度参数用于控制分布的平滑程度, 温度越高分布越平滑
Args:
- student_logits: 学生模型的原始输出 logits, shape [batch, seq_len, vocab_size]
- teacher_logits: 教师模型的原始输出 logits, shape [batch, seq_len, vocab_size]
- temperature: 蒸馏温度系数, 用于平滑 softmax 分布, 默认为 1.0
- reduction: 损失聚合方式, 默认为 'batchmean' (按批次平均)
Returns:
- float: 计算得到的蒸馏损失值, 已乘以温度平方
"""
# 计算教师模型在温度下的概率分布, 并 detach 避免梯度传播到教师模型
with torch.no_grad():
teacher_probs = F.softmax(teacher_logits / temperature, dim=-1).detach()
# 计算学生模型在温度下的对数概率分布
student_log_probs = F.log_softmax(student_logits / temperature, dim=-1)
# 计算 KL 散度损失
kl = F.kl_div(
student_log_probs,
teacher_probs,
reduction=reduction
)
# 乘以温度平方进行缩放, 保持梯度大小与标准交叉熵一致
return (temperature ** 2) * kl
def train_epoch(epoch, loader, iters, teacher_model, lm_config_student, start_step=0, wandb=None, alpha=0.0, temperature=1.0):
"""
训练一个 epoch
执行知识蒸馏训练的前向传播损失计算和反向传播更新
包含标准交叉熵损失和蒸馏损失两种监督信号
Args:
- epoch: 当前训练轮次 ( 0 开始计数)
- loader: 数据加载器, 提供批次数据
- iters: 当前 epoch 的总迭代次数
- teacher_model: 教师模型, 用于生成软标签进行蒸馏
- lm_config_student: 学生模型配置, 用于判断是否为 MoE 架构
- start_step: 起始步数, 用于从检查点恢复训练
- wandb: Weights & Biases 日志对象, 用于记录训练指标
- alpha: 损失权重, 平衡 CE 损失和蒸馏损失, 总损失 = alpha*CE + (1-alpha)*KL
- temperature: 蒸馏温度, 用于控制分布平滑程度
"""
start_time = time.time()
# 设置教师模型为评估模式, 禁用 dropout 等随机操作
if teacher_model is not None:
teacher_model.eval()
# 冻结教师模型参数, 不计算梯度
teacher_model.requires_grad_(False)
# 遍历数据加载器中的每个批次
for step, (input_ids, labels) in enumerate(loader, start=start_step + 1):
# 将数据移动到指定设备
input_ids = input_ids.to(args.device)
labels = labels.to(args.device)
# 创建损失掩码, 标记哪些位置需要计算损失 (ignore_index=-100 的位置为 0)
loss_mask = (labels[..., 1:] != -100).float()
# 使用余弦退火策略计算当前学习率
lr = get_lr(epoch * iters + step, args.epochs * iters, args.learning_rate)
# 更新优化器中所有参数组的学习率
for param_group in optimizer.param_groups:
param_group['lr'] = lr
# 前向传播(学生模型)
# 前向传播 (学生模型)
# - 在 __name__ == "__main__" 中定义
# - autocast_ctx = nullcontext() if device_type == "cpu" else torch.amp.autocast(dtype=dtype)
with autocast_ctx:
res = model(input_ids)
# 提取学生模型的 logits, 去掉最后一个位置的预测 (因为标签向右偏移了一位)
student_logits = res.logits[..., :-1, :].contiguous()
# 教师模型前向传播只在eval & no_grad
# 教师模型前向传播 (只在 eval & no_grad 模式下)
if teacher_model is not None:
with torch.no_grad():
teacher_logits = teacher_model(input_ids).logits[..., :-1, :].contiguous()
# 如果学生模型和教师模型的词表大小不同, 只取教师模型前 vocab_size_student 个
vocab_size_student = student_logits.size(-1)
teacher_logits = teacher_logits[..., :vocab_size_student]
# ========== 计算损失 ==========
# 1) Ground-Truth CE Loss
# 计算损失
# 1) Ground-Truth CE Loss (标准交叉熵损失)
shift_labels = labels[..., 1:].contiguous()
loss_mask_flat = loss_mask.view(-1)
ce_loss = F.cross_entropy(
@ -72,42 +122,61 @@ def train_epoch(epoch, loader, iters, teacher_model, lm_config_student, start_st
ignore_index=-100,
reduction='none'
)
# 计算加权平均损失, 忽略被 mask 的位置
ce_loss_raw = torch.sum(ce_loss * loss_mask_flat) / (loss_mask_flat.sum() + 1e-8)
if lm_config_student.use_moe: ce_loss = ce_loss_raw + res.aux_loss
else: ce_loss = ce_loss_raw
# 2) Distillation Loss
# 如果是 MoE 模型, 加上辅助损失
if lm_config_student.use_moe:
ce_loss = ce_loss_raw + res.aux_loss
else:
ce_loss = ce_loss_raw
# 2) Distillation Loss (蒸馏损失)
if teacher_model is not None:
# 只在被 mask 的位置上计算蒸馏损失
distill_loss = distillation_loss(
student_logits.view(-1, student_logits.size(-1))[loss_mask_flat == 1],
teacher_logits.view(-1, teacher_logits.size(-1))[loss_mask_flat == 1],
temperature=temperature
)
else:
# 如果没有教师模型, 蒸馏损失为 0
distill_loss = torch.tensor(0.0, device=args.device)
# 3) 总损失 = alpha * CE + (1-alpha) * Distill
# 3) 总损失 = alpha * CE + (1-alpha) * Distill,再除以梯度累积步数
loss = (alpha * ce_loss + (1 - alpha) * distill_loss) / args.accumulation_steps
# 反向传播
scaler.scale(loss).backward()
# 梯度累积步数达到后, 执行优化器更新
if (step + 1) % args.accumulation_steps == 0:
# 取消梯度缩放, 准备更新
scaler.unscale_(optimizer)
# 梯度裁剪, 防止梯度爆炸
torch.nn.utils.clip_grad_norm_(model.parameters(), args.grad_clip)
# 执行优化器更新
scaler.step(optimizer)
# 更新缩放器
scaler.update()
# 清空梯度, 将 None 赋值给梯度以节省内存
optimizer.zero_grad(set_to_none=True)
# 记录日志
if step % args.log_interval == 0 or step == iters - 1:
spend_time = time.time() - start_time
# 恢复实际的损失值 (乘以累积步数)
current_loss = loss.item() * args.accumulation_steps
current_ce_loss = ce_loss_raw.item()
current_aux_loss = res.aux_loss.item() if lm_config_student.use_moe else 0.0
current_lr = optimizer.param_groups[-1]['lr']
# 估算剩余时间
eta_min = spend_time / (step + 1) * iters // 60 - spend_time // 60
# 打印训练日志
Logger(f'Epoch:[{epoch + 1}/{args.epochs}]({step}/{iters}), loss: {current_loss:.4f}, ce: {current_ce_loss:.4f}, aux_loss: {current_aux_loss:.4f}, distill: {distill_loss.item():.4f}, learning_rate: {current_lr:.8f}, epoch_time: {eta_min:.3f}min')
# 记录到 wandb
if wandb:
wandb.log({
"loss": current_loss,
@ -118,18 +187,26 @@ def train_epoch(epoch, loader, iters, teacher_model, lm_config_student, start_st
"epoch_time": eta_min
})
# 保存检查点
if (step % args.save_interval == 0 or step == iters - 1) and is_main_process():
model.eval()
# 根据是否为 MoE 模型添加后缀
moe_suffix = '_moe' if lm_config_student.use_moe else ''
ckp = f'{args.save_dir}/{args.save_weight}_{lm_config_student.hidden_size}{moe_suffix}.pth'
# 获取原始模型 (解包 DDP)
raw_model = model.module if isinstance(model, DistributedDataParallel) else model
raw_model = getattr(raw_model, '_orig_mod', raw_model)
# 保存模型权重为半精度格式
state_dict = raw_model.state_dict()
torch.save({k: v.half().cpu() for k, v in state_dict.items()}, ckp)
# 保存完整检查点 (包含优化器状态等)
lm_checkpoint(lm_config_student, weight=args.save_weight, model=model, optimizer=optimizer, scaler=scaler, epoch=epoch, step=step, wandb=wandb, save_dir='../checkpoints')
# 恢复训练模式
model.train()
# 清理内存
del state_dict
# 清理本批次的数据, 释放显存
del input_ids, labels, loss_mask, res, student_logits, ce_loss, distill_loss, loss
@ -147,40 +224,42 @@ if __name__ == "__main__":
parser.add_argument("--grad_clip", type=float, default=1.0, help="梯度裁剪阈值")
parser.add_argument("--log_interval", type=int, default=100, help="日志打印间隔")
parser.add_argument("--save_interval", type=int, default=100, help="模型保存间隔")
parser.add_argument("--max_seq_len", type=int, default=340, help="训练的最大截断长度中文1token≈1.5~1.7字符)")
parser.add_argument("--max_seq_len", type=int, default=340, help="训练的最大截断长度 (中文1token≈1.5~1.7字符)")
parser.add_argument("--data_path", type=str, default="../dataset/sft_mini_512.jsonl", help="训练数据路径")
parser.add_argument('--student_hidden_size', default=512, type=int, help="学生模型隐藏层维度")
parser.add_argument('--student_num_layers', default=8, type=int, help="学生模型隐藏层数量")
parser.add_argument('--teacher_hidden_size', default=768, type=int, help="教师模型隐藏层维度")
parser.add_argument('--teacher_num_layers', default=16, type=int, help="教师模型隐藏层数量")
parser.add_argument('--use_moe', default=0, type=int, choices=[0, 1], help="是否使用MoE架构0=否1=是)")
parser.add_argument('--use_moe', default=0, type=int, choices=[0, 1], help="是否使用MoE架构 (0=否, 1=是)")
parser.add_argument('--from_student_weight', default='full_sft', type=str, help="学生模型基于哪个权重")
parser.add_argument('--from_teacher_weight', default='full_sft', type=str, help="教师模型基于哪个权重")
parser.add_argument('--from_resume', default=0, type=int, choices=[0, 1], help="是否自动检测&续训0=否1=是)")
parser.add_argument('--alpha', default=0.5, type=float, help="CE损失权重总损失=alpha*CE+(1-alpha)*KL")
parser.add_argument('--temperature', default=1.5, type=float, help="蒸馏温度推荐范围1.0-2.0")
parser.add_argument('--from_resume', default=0, type=int, choices=[0, 1], help="是否自动检测&续训 (0=否, 1=是)")
parser.add_argument('--alpha', default=0.5, type=float, help="CE损失权重, 总损失=alpha*CE+(1-alpha)*KL")
parser.add_argument('--temperature', default=1.5, type=float, help="蒸馏温度 (推荐范围1.0-2.0)")
parser.add_argument("--use_wandb", action="store_true", help="是否使用wandb")
parser.add_argument("--wandb_project", type=str, default="MiniMind-Distillation", help="wandb项目名")
parser.add_argument("--use_compile", default=0, type=int, choices=[0, 1], help="是否使用torch.compile加速0=否1=是)")
parser.add_argument("--use_compile", default=0, type=int, choices=[0, 1], help="是否使用torch.compile加速 (0=否, 1=是)")
args = parser.parse_args()
# ========== 1. 初始化环境和随机种子 ==========
# ========== 1.初始化环境和随机种子 ==========
local_rank = init_distributed_mode()
if dist.is_initialized(): args.device = f"cuda:{local_rank}"
setup_seed(42 + (dist.get_rank() if dist.is_initialized() else 0))
# ========== 2. 配置目录、模型参数、检查ckp ==========
# ========== 2.配置目录、模型参数、检查ckp ==========
os.makedirs(args.save_dir, exist_ok=True)
lm_config_student = MiniMindConfig(hidden_size=args.student_hidden_size, num_hidden_layers=args.student_num_layers, use_moe=bool(args.use_moe))
lm_config_teacher = MiniMindConfig(hidden_size=args.teacher_hidden_size, num_hidden_layers=args.teacher_num_layers, use_moe=bool(args.use_moe))
ckp_data = lm_checkpoint(lm_config_student, weight=args.save_weight, save_dir='../checkpoints') if args.from_resume==1 else None
# ========== 3. 设置混合精度 ==========
# ========== 3.配置混合精度训练 ==========
device_type = "cuda" if "cuda" in args.device else "cpu"
dtype = torch.bfloat16 if args.dtype == "bfloat16" else torch.float16
autocast_ctx = nullcontext() if device_type == "cpu" else torch.cuda.amp.autocast(dtype=dtype)
# https://docs.pytorch.ac.cn/docs/stable/amp
# 新版废弃: autocast_ctx = nullcontext() if device_type == "cpu" else torch.cuda.amp.autocast(dtype=dtype)
autocast_ctx = nullcontext() if device_type == "cpu" else torch.amp.autocast(device_type=device_type, dtype=dtype)
# ========== 4. 配wandb ==========
# ========== 4.配置 wandb ==========
wandb = None
if args.use_wandb and is_main_process():
import swanlab as wandb
@ -189,22 +268,32 @@ if __name__ == "__main__":
wandb_run_name = f"MiniMind-Distill-S{args.student_hidden_size}T{args.teacher_hidden_size}-Epoch-{args.epochs}-BS-{args.batch_size}-LR-{args.learning_rate}"
wandb.init(project=args.wandb_project, name=wandb_run_name, id=wandb_id, resume=resume)
# ========== 5. 定义学生和教师模型 ==========
# ========== 5.定义学生和教师模型 ==========
# 学生模型和分词器
model, tokenizer = init_model(lm_config_student, args.from_student_weight, device=args.device)
# 使用 torch.compile 加速模型 (如果支持)
if args.use_compile == 1:
model = torch.compile(model)
Logger('torch.compile enabled')
Logger(f'学生模型总参数量:{sum(p.numel() for p in model.parameters()) / 1e6:.3f} M')
# 教师模型
teacher_model, _ = init_model(lm_config_teacher, args.from_teacher_weight, device=args.device)
teacher_model.eval()
teacher_model.requires_grad_(False)
Logger(f'教师模型总参数量:{sum(p.numel() for p in teacher_model.parameters()) / 1e6:.3f} M')
# 创建训练数据集
train_ds = SFTDataset(args.data_path, tokenizer, max_length=args.max_seq_len)
# 创建分布式采样器 (如果在分布式环境下)
train_sampler = DistributedSampler(train_ds) if dist.is_initialized() else None
scaler = torch.cuda.amp.GradScaler(enabled=(args.dtype == 'float16'))
# 创建梯度缩放器 (混合精度训练需要)
# https://docs.pytorch.ac.cn/docs/stable/amp.html#gradient-scaling
# 新版废弃: scaler = torch.cuda.amp.GradScaler(enabled=(args.dtype == 'float16'))
scaler = torch.amp.GradScaler("cuda", enabled=(args.dtype == 'float16'))
# 创建优化器
optimizer = optim.AdamW(model.parameters(), lr=args.learning_rate)
# ========== 6. 从ckp恢复状态 ==========
# ========== 6.从检查点恢复训练状态 ==========
start_epoch, start_step = 0, 0
if ckp_data:
model.load_state_dict(ckp_data['model'])
@ -213,23 +302,29 @@ if __name__ == "__main__":
start_epoch = ckp_data['epoch']
start_step = ckp_data.get('step', 0)
# ========== 7. DDP包模型 ==========
# ========== 7.使用 DDP 模型 (如果在分布式环境下) ==========
if dist.is_initialized():
model._ddp_params_and_buffers_to_ignore = {"freqs_cos", "freqs_sin"}
model = DistributedDataParallel(model, device_ids=[local_rank])
# ========== 8. 开始训练 ==========
# ========== 8.开始训练 ==========
for epoch in range(start_epoch, args.epochs):
# 设置分布式采样器的 epoch (打乱数据)
train_sampler and train_sampler.set_epoch(epoch)
setup_seed(42 + epoch); indices = torch.randperm(len(train_ds)).tolist()
setup_seed(42 + epoch)
# 随机打乱数据集索引
indices = torch.randperm(len(train_ds)).tolist()
# 计算需要跳过的批次数量
skip = start_step if (epoch == start_epoch and start_step > 0) else 0
# 创建批次采样器
batch_sampler = SkipBatchSampler(train_sampler or indices, args.batch_size, skip)
loader = DataLoader(train_ds, batch_sampler=batch_sampler, num_workers=args.num_workers, pin_memory=True)
# 如果有跳过的批次, 打印提示信息并恢复训练
if skip > 0:
Logger(f'Epoch [{epoch + 1}/{args.epochs}]: 跳过前{start_step}个step从step {start_step + 1}开始')
train_epoch(epoch, loader, len(loader) + skip, teacher_model, lm_config_student, start_step, wandb, args.alpha, args.temperature)
else:
train_epoch(epoch, loader, len(loader), teacher_model, lm_config_student, 0, wandb, args.alpha, args.temperature)
# ========== 9. 清理分布进程 ==========
# ========== 9.清理分布进程 ==========
if dist.is_initialized(): dist.destroy_process_group()

View File

@ -1,12 +1,13 @@
import os
import sys
# 设置包名
__package__ = "trainer"
# 将项目根目录添加到系统路径, 确保能够导入同级目录的模块
sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), '..')))
import argparse
import time
import warnings
import torch
import torch.nn.functional as F
import torch.distributed as dist
@ -18,10 +19,25 @@ from model.model_minimind import MiniMindConfig
from dataset.lm_dataset import DPODataset
from trainer.trainer_utils import get_lr, Logger, is_main_process, lm_checkpoint, init_distributed_mode, setup_seed, init_model, SkipBatchSampler
# 忽略所有警告信息, 保持输出整洁
import warnings
warnings.filterwarnings('ignore')
def logits_to_log_probs(logits, labels):
"""
logits 转换为对数概率
对每个位置的 logits 计算 log_softmax, 然后根据 labels 提取对应的对数概率
这是 DPO 训练中的基础操作, 用于计算策略模型和参考模型的 log probabilities
Args:
- logits: 模型输出的 logits, shape: (batch_size, seq_len, vocab_size)
- labels: 目标 token 序列, shape: (batch_size, seq_len)
Returns:
- log_probs_per_token: 每个 token 的对数概率, shape: (batch_size, seq_len)
"""
# logits shape: (batch_size, seq_len, vocab_size)
# labels shape: (batch_size, seq_len)
# log_probs shape: (batch_size, seq_len)
@ -31,19 +47,36 @@ def logits_to_log_probs(logits, labels):
def dpo_loss(ref_log_probs, policy_log_probs, mask, beta):
"""
计算 DPO (Direct Preference Optimization) 损失
DPO 是一种直接根据人类偏好数据优化语言模型的方法, 不需要训练奖励模型
通过比较策略模型和参考模型在 chosen/rejected 样本上的对数概率比率来计算损失
Args:
- ref_log_probs: 参考模型对数概率, shape: (batch_size, seq_len)
- policy_log_probs: 策略模型对数概率, shape: (batch_size, seq_len)
- mask: 有效 token 掩码, shape: (batch_size, seq_len)
- beta: DPO 温度参数, 控制模型偏离参考模型的程度
Returns:
- loss: 平均 DPO 损失值
"""
# ref_log_probs 和 policy_log_probs 都是 shape: (batch_size, seq_len)
# https://github.com/jingyaogong/minimind/issues/298
seq_lengths = mask.sum(dim=1, keepdim=True).clamp_min(1e-8) # 防止零长度mask导致除零NaN
# 计算每个序列的有效长度, 用于对 log probs 进行归一化
seq_lengths = mask.sum(dim=1, keepdim=True).clamp_min(1e-8) # 防止零长度 mask 导致除零 NaN
ref_log_probs = (ref_log_probs * mask).sum(dim=1) / seq_lengths.squeeze()
policy_log_probs = (policy_log_probs * mask).sum(dim=1) / seq_lengths.squeeze()
# 将 chosen 和 rejected 数据分开
# 将 chosen 和 rejected 数据分开, 批次的前半部分是 chosen, 后半部分是 rejected
batch_size = ref_log_probs.shape[0]
chosen_ref_log_probs = ref_log_probs[:batch_size // 2]
reject_ref_log_probs = ref_log_probs[batch_size // 2:]
chosen_policy_log_probs = policy_log_probs[:batch_size // 2]
reject_policy_log_probs = policy_log_probs[batch_size // 2:]
# 计算对数概率比率, 衡量策略模型相比参考模型的改进程度
pi_logratios = chosen_policy_log_probs - reject_policy_log_probs
ref_logratios = chosen_ref_log_probs - reject_ref_log_probs
logits = pi_logratios - ref_logratios
@ -52,39 +85,62 @@ def dpo_loss(ref_log_probs, policy_log_probs, mask, beta):
def train_epoch(epoch, loader, iters, ref_model, lm_config, start_step=0, wandb=None, beta=0.1):
"""
执行一个 epoch DPO 训练
Args:
- epoch: 当前训练轮次
- loader: 数据加载器
- iters: 当前 epoch 的总迭代次数
- ref_model: 参考模型 (冻结参数, 不更新)
- lm_config: 模型配置对象
- start_step: 从此步骤开始训练 (用于恢复训练)
- wandb: Weights & Biases 日志对象
- beta: DPO 温度参数
"""
start_time = time.time()
for step, batch in enumerate(loader, start=start_step + 1):
# 从批次中提取 chosen 和 rejected 数据, 并移动到目标设备
x_chosen = batch['x_chosen'].to(args.device)
x_rejected = batch['x_rejected'].to(args.device)
y_chosen = batch['y_chosen'].to(args.device)
y_rejected = batch['y_rejected'].to(args.device)
mask_chosen = batch['mask_chosen'].to(args.device)
mask_rejected = batch['mask_rejected'].to(args.device)
# 将 chosen 和 rejected 数据拼接在一起, 形成完整的批次
# 批次的前半部分是 chosen (用户偏好的回复), 后半部分是 rejected (用户不喜欢的回复)
x = torch.cat([x_chosen, x_rejected], dim=0)
y = torch.cat([y_chosen, y_rejected], dim=0)
mask = torch.cat([mask_chosen, mask_rejected], dim=0)
# 根据当前步骤动态计算学习率 (余弦退火策略)
lr = get_lr(epoch * iters + step, args.epochs * iters, args.learning_rate)
for param_group in optimizer.param_groups:
param_group['lr'] = lr
with autocast_ctx:
# 参考模型前向传播, 不计算梯度 (frozen)
with torch.no_grad():
ref_outputs = ref_model(x)
ref_logits = ref_outputs.logits
ref_log_probs = logits_to_log_probs(ref_logits, y)
# 策略模型前向传播, 计算 logits 和 log probabilities
outputs = model(x)
logits = outputs.logits
policy_log_probs = logits_to_log_probs(logits, y)
# 计算 DPO 损失和辅助损失 (MoE 负载均衡损失)
dpo_loss_val = dpo_loss(ref_log_probs, policy_log_probs, mask, beta=beta)
loss = dpo_loss_val + outputs.aux_loss
loss = loss / args.accumulation_steps
# 使用梯度缩放器进行反向传播, 支持混合精度训练
scaler.scale(loss).backward()
# 当累积达到指定步数时, 执行参数更新
if (step + 1) % args.accumulation_steps == 0:
scaler.unscale_(optimizer)
torch.nn.utils.clip_grad_norm_(model.parameters(), args.grad_clip)
@ -92,6 +148,7 @@ def train_epoch(epoch, loader, iters, ref_model, lm_config, start_step=0, wandb=
scaler.update()
optimizer.zero_grad(set_to_none=True)
# 按指定间隔打印日志
if step % args.log_interval == 0 or step == iters - 1:
spend_time = time.time() - start_time
current_loss = loss.item() * args.accumulation_steps
@ -104,10 +161,12 @@ def train_epoch(epoch, loader, iters, ref_model, lm_config, start_step=0, wandb=
if wandb: wandb.log({"loss": current_loss, "dpo_loss": current_dpo_loss, "aux_loss": current_aux_loss, "learning_rate": current_lr, "epoch_time": eta_min})
# 按指定间隔保存模型检查点, 仅主进程执行
if (step % args.save_interval == 0 or step == iters - 1) and is_main_process():
model.eval()
moe_suffix = '_moe' if lm_config.use_moe else ''
ckp = f'{args.save_dir}/{args.save_weight}_{lm_config.hidden_size}{moe_suffix}.pth'
# 解包 DDP 包装的模型以获取原始模型
raw_model = model.module if isinstance(model, DistributedDataParallel) else model
raw_model = getattr(raw_model, '_orig_mod', raw_model)
state_dict = raw_model.state_dict()
@ -116,17 +175,19 @@ def train_epoch(epoch, loader, iters, ref_model, lm_config, start_step=0, wandb=
model.train()
del state_dict
# 删除中间变量以释放内存
del x_chosen, x_rejected, y_chosen, y_rejected, mask_chosen, mask_rejected, x, y, mask
del ref_outputs, ref_logits, ref_log_probs, outputs, logits, policy_log_probs, loss
if __name__ == "__main__":
# 创建参数解析器, 定义 DPO 训练的所有配置参数
parser = argparse.ArgumentParser(description="MiniMind DPO (Direct Preference Optimization)")
parser.add_argument("--save_dir", type=str, default="../out", help="模型保存目录")
parser.add_argument('--save_weight', default='dpo', type=str, help="保存权重的前缀名")
parser.add_argument("--epochs", type=int, default=1, help="训练轮数")
parser.add_argument("--batch_size", type=int, default=4, help="batch size")
parser.add_argument("--learning_rate", type=float, default=4e-8, help="初始学习率(建议<=5e-8避免遗忘")
parser.add_argument("--learning_rate", type=float, default=4e-8, help="初始学习率 (建议<=5e-8避免遗忘)")
parser.add_argument("--device", type=str, default="cuda:0" if torch.cuda.is_available() else "cpu", help="训练设备")
parser.add_argument("--dtype", type=str, default="bfloat16", help="混合精度类型")
parser.add_argument("--num_workers", type=int, default=8, help="数据加载线程数")
@ -136,84 +197,113 @@ if __name__ == "__main__":
parser.add_argument("--save_interval", type=int, default=100, help="模型保存间隔")
parser.add_argument('--hidden_size', default=512, type=int, help="隐藏层维度")
parser.add_argument('--num_hidden_layers', default=8, type=int, help="隐藏层数量")
parser.add_argument('--max_seq_len', default=1024, type=int, help="训练的最大截断长度中文1token≈1.5~1.7字符)")
parser.add_argument('--use_moe', default=0, type=int, choices=[0, 1], help="是否使用MoE架构0=否1=是)")
parser.add_argument('--max_seq_len', default=1024, type=int, help="训练的最大截断长度 (中文1token≈1.5~1.7字符)")
parser.add_argument('--use_moe', default=0, type=int, choices=[0, 1], help="是否使用MoE架构 (0=否, 1=是)")
parser.add_argument("--data_path", type=str, default="../dataset/dpo.jsonl", help="DPO训练数据路径")
parser.add_argument('--from_weight', default='full_sft', type=str, help="基于哪个权重训练")
parser.add_argument('--from_resume', default=0, type=int, choices=[0, 1], help="是否自动检测&续训0=否1=是)")
parser.add_argument('--from_resume', default=0, type=int, choices=[0, 1], help="是否自动检测&续训 (0=否, 1=是)")
parser.add_argument('--beta', default=0.1, type=float, help="DPO中的beta参数")
parser.add_argument("--use_wandb", action="store_true", help="是否使用wandb")
parser.add_argument("--wandb_project", type=str, default="MiniMind-DPO", help="wandb项目名")
parser.add_argument("--use_compile", default=0, type=int, choices=[0, 1], help="是否使用torch.compile加速0=否1=是)")
parser.add_argument("--use_compile", default=0, type=int, choices=[0, 1], help="是否使用torch.compile加速 (0=否, 1=是)")
args = parser.parse_args()
# ========== 1. 初始化环境和随机种子 ==========
# ========== 1.初始化环境和随机种子 ==========
# 初始化分布式训练模式, 返回本地 rank (非 DDP 模式下返回 0)
local_rank = init_distributed_mode()
# 如果处于分布式训练环境, 根据 local_rank 设置当前设备
if dist.is_initialized(): args.device = f"cuda:{local_rank}"
# 设置随机种子, 分布式训练时根据 rank 偏移种子以确保不同进程使用不同的数据顺序
setup_seed(42 + (dist.get_rank() if dist.is_initialized() else 0))
# ========== 2. 配置目录、模型参数、检查ckp ==========
# ========== 2.配置目录、模型参数、检查检查点 ==========
# 创建保存目录 (如果不存在)
os.makedirs(args.save_dir, exist_ok=True)
# 创建 MiniMind 模型配置
lm_config = MiniMindConfig(hidden_size=args.hidden_size, num_hidden_layers=args.num_hidden_layers, use_moe=bool(args.use_moe))
# 如果启用续训模式, 尝试从检查点恢复训练状态
ckp_data = lm_checkpoint(lm_config, weight=args.save_weight, save_dir='../checkpoints') if args.from_resume==1 else None
# ========== 3. 设置混合精度 ==========
# ========== 3.设置混合精度训练 ==========
device_type = "cuda" if "cuda" in args.device else "cpu"
# 根据配置选择数据类型 (bfloat16 或 float16)
dtype = torch.bfloat16 if args.dtype == "bfloat16" else torch.float16
autocast_ctx = nullcontext() if device_type == "cpu" else torch.cuda.amp.autocast(dtype=dtype)
# https://docs.pytorch.ac.cn/docs/stable/amp
# 新版废弃: autocast_ctx = nullcontext() if device_type == "cpu" else torch.cuda.amp.autocast(dtype=dtype)
autocast_ctx = nullcontext() if device_type == "cpu" else torch.amp.autocast(device_type=device_type, dtype=dtype)
# ========== 4. 配wandb ==========
# ========== 4.配置 wandb 日志 ==========
wandb = None
# 仅在主进程中初始化 wandb, 避免重复记录
if args.use_wandb and is_main_process():
import swanlab as wandb
# 从检查点恢复 wandb 运行 ID (如果存在)
wandb_id = ckp_data.get('wandb_id') if ckp_data else None
resume = 'must' if wandb_id else None
# 构建运行名称, 包含关键训练参数
wandb_run_name = f"MiniMind-DPO-Epoch-{args.epochs}-BatchSize-{args.batch_size}-LR-{args.learning_rate}"
wandb.init(project=args.wandb_project, name=wandb_run_name, id=wandb_id, resume=resume)
# ========== 5. 定义模型和参考模型 ==========
# ========== 5.初始化模型和参考模型 ==========
model, tokenizer = init_model(lm_config, args.from_weight, device=args.device)
# 如果启用 torch.compile, 使用 PyTorch 2.0 的编译加速功能
if args.use_compile == 1:
model = torch.compile(model)
Logger('torch.compile enabled')
Logger(f'策略模型总参数量{sum(p.numel() for p in model.parameters()) / 1e6:.3f} M')
# 初始化参考模型ref_model冻结
Logger(f'策略模型总参数量: {sum(p.numel() for p in model.parameters()) / 1e6:.3f} M')
# 初始化参考模型 (ref_model 冻结, 不参与训练, 仅用于计算参考 logits)
ref_model, _ = init_model(lm_config, args.from_weight, device=args.device)
ref_model.eval()
ref_model.requires_grad_(False)
Logger(f'参考模型总参数量{sum(p.numel() for p in ref_model.parameters()) / 1e6:.3f} M')
Logger(f'参考模型总参数量: {sum(p.numel() for p in ref_model.parameters()) / 1e6:.3f} M')
# 创建 DPO 数据集和数据采样器
train_ds = DPODataset(args.data_path, tokenizer, max_length=args.max_seq_len)
# 分布式训练时使用 DistributedSampler 确保数据分配均匀
train_sampler = DistributedSampler(train_ds) if dist.is_initialized() else None
scaler = torch.cuda.amp.GradScaler(enabled=(args.dtype == 'float16'))
# 创建梯度缩放器, 仅在 float16 模式下启用 (用于混合精度训练)
# https://docs.pytorch.ac.cn/docs/stable/amp.html#gradient-scaling
# 新版废弃: scaler = torch.cuda.amp.GradScaler(enabled=(args.dtype == 'float16'))
scaler = torch.amp.GradScaler("cuda", enabled=(args.dtype == 'float16'))
# 创建 AdamW 优化器
optimizer = optim.AdamW(model.parameters(), lr=args.learning_rate)
# ========== 6. 从ckp恢复状态 ==========
# ========== 6.从检查点恢复训练状态 ==========
start_epoch, start_step = 0, 0
if ckp_data:
# 从检查点加载模型权重、优化器状态、梯度缩放器状态
model.load_state_dict(ckp_data['model'])
optimizer.load_state_dict(ckp_data['optimizer'])
scaler.load_state_dict(ckp_data['scaler'])
# 恢复训练进度
start_epoch = ckp_data['epoch']
start_step = ckp_data.get('step', 0)
# ========== 7. DDP包模型 ==========
# ========== 7.使用 DDP 模型 (分布式训练) ==========
if dist.is_initialized():
# 忽略位置编码相关的参数 (不需要梯度同步)
model._ddp_params_and_buffers_to_ignore = {"freqs_cos", "freqs_sin"}
# 使用 DistributedDataParallel 包装模型, 启用分布式训练
model = DistributedDataParallel(model, device_ids=[local_rank])
# ========== 8. 开始训练 ==========
# ========== 8.开始训练循环 ==========
for epoch in range(start_epoch, args.epochs):
# 分布式训练时设置 epoch 以确保不同进程使用不同的数据顺序
train_sampler and train_sampler.set_epoch(epoch)
# 设置当前 epoch 的随机种子, 生成随机索引序列
setup_seed(42 + epoch); indices = torch.randperm(len(train_ds)).tolist()
# 如果是恢复训练的第一个 epoch, 计算需要跳过的步数
skip = start_step if (epoch == start_epoch and start_step > 0) else 0
# 创建支持跳过功能的批次采样器
batch_sampler = SkipBatchSampler(train_sampler or indices, args.batch_size, skip)
# 创建数据加载器
loader = DataLoader(train_ds, batch_sampler=batch_sampler, num_workers=args.num_workers, pin_memory=True)
# 执行当前 epoch 的训练
if skip > 0:
Logger(f'Epoch [{epoch + 1}/{args.epochs}]: 跳过前{start_step}stepstep {start_step + 1}开始')
Logger(f'Epoch [{epoch + 1}/{args.epochs}]: 跳过前{start_step} step, 从 step {start_step + 1}开始')
train_epoch(epoch, loader, len(loader) + skip, ref_model, lm_config, start_step, wandb, args.beta)
else:
train_epoch(epoch, loader, len(loader), ref_model, lm_config, 0, wandb, args.beta)
# ========== 9. 清理分布进程 ==========
# ========== 9.清理分布进程 ==========
if dist.is_initialized(): dist.destroy_process_group()

View File

@ -1,12 +1,13 @@
import os
import sys
# 设置包名
__package__ = "trainer"
# 将项目根目录添加到系统路径, 确保能够导入同级目录的模块
sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), '..')))
import argparse
import time
import warnings
import torch
import torch.distributed as dist
from contextlib import nullcontext
@ -17,57 +18,112 @@ from model.model_minimind import MiniMindConfig
from dataset.lm_dataset import SFTDataset
from trainer.trainer_utils import get_lr, Logger, is_main_process, lm_checkpoint, init_distributed_mode, setup_seed, init_model, SkipBatchSampler
# 忽略所有警告信息, 避免训练过程中输出无关的警告日志
import warnings
warnings.filterwarnings('ignore')
def train_epoch(epoch, loader, iters, start_step=0, wandb=None):
"""
单个 epoch 的训练函数
执行一个完整的训练周期, 包括前向传播反向传播梯度更新日志打印和模型保存
Args:
- epoch: 当前训练轮次, 0 开始计数
- loader: 数据加载器, 提供训练数据批次
- iters: 当前 epoch 的总迭代次数
- start_step: 起始步数, 用于恢复训练时跳过已训练的批次, 默认为 0
- wandb: Weights & Biases 日志对象, None 时不记录日志
"""
# 记录 epoch 开始时间, 用于计算 ETA (预计剩余时间)
start_time = time.time()
# 遍历数据加载器中的所有批次
# enumerate 从 start_step + 1 开始, 保证恢复训练时 step 编号连续
for step, (input_ids, labels) in enumerate(loader, start=start_step + 1):
# 将输入数据和标签移动到训练设备 (GPU 或 CPU)
input_ids = input_ids.to(args.device)
labels = labels.to(args.device)
# 根据余弦退火策略计算当前步的学习率
# 公式参数: 当前全局步数 (epoch * iters + step), 总步数 (epochs * iters), 基础学习率
lr = get_lr(epoch * iters + step, args.epochs * iters, args.learning_rate)
# 更新优化器中所有参数组的学习率
for param_group in optimizer.param_groups:
param_group['lr'] = lr
# 使用混合精度上下文进行前向传播
# 在 CUDA 设备上使用 bfloat16/float16, CPU 设备上不使用混合精度
with autocast_ctx:
# 模型前向传播, 输入 input_ids 和 labels, 返回 loss 和 aux_loss
res = model(input_ids, labels=labels)
# 总损失 = 主损失 (交叉熵) + 辅助损失 (MoE 负载均衡损失)
loss = res.loss + res.aux_loss
# 梯度累积: 将损失除以累积步数, 模拟更大 batch size 的训练效果
loss = loss / args.accumulation_steps
# 使用 GradScaler 进行反向传播, 支持混合精度训练的梯度缩放
scaler.scale(loss).backward()
# 每 accumulation_steps 步执行一次梯度更新
if (step + 1) % args.accumulation_steps == 0:
# 反缩放梯度, 将缩放后的梯度还原为真实梯度值
scaler.unscale_(optimizer)
# 梯度裁剪, 防止梯度爆炸, 限制梯度的 L2 范数不超过 grad_clip
torch.nn.utils.clip_grad_norm_(model.parameters(), args.grad_clip)
# 执行优化器更新步骤
scaler.step(optimizer)
# 更新 GradScaler 的缩放因子, 根据梯度是否溢出动态调整
scaler.update()
# 清空梯度, set_to_none=True 可以节省内存
optimizer.zero_grad(set_to_none=True)
# 每 log_interval 步打印一次日志, 或在该 epoch 最后一步打印
if step % args.log_interval == 0 or step == iters - 1:
# 计算已花费的时间
spend_time = time.time() - start_time
# 还原真实的损失值 (之前被除以 accumulation_steps)
current_loss = loss.item() * args.accumulation_steps
# 获取 MoE 辅助损失, 如果不存在则为 0
current_aux_loss = res.aux_loss.item() if res.aux_loss is not None else 0.0
# 计算主损失 (总损失减去辅助损失)
current_logits_loss = current_loss - current_aux_loss
# 获取当前学习率
current_lr = optimizer.param_groups[-1]['lr']
# 计算 ETA (预计剩余时间), 单位为分钟
eta_min = spend_time / (step + 1) * iters // 60 - spend_time // 60
# 打印训练日志
Logger(f'Epoch:[{epoch + 1}/{args.epochs}]({step}/{iters}), loss: {current_loss:.4f}, logits_loss: {current_logits_loss:.4f}, aux_loss: {current_aux_loss:.4f}, lr: {current_lr:.8f}, epoch_time: {eta_min:.1f}min')
# 如果启用了 wandb, 记录指标
if wandb: wandb.log({"loss": current_loss, "logits_loss": current_logits_loss, "aux_loss": current_aux_loss, "learning_rate": current_lr, "epoch_time": eta_min})
# 每 save_interval 步保存一次模型, 或在 epoch 最后一步保存, 仅在主进程中执行
if (step % args.save_interval == 0 or step == iters - 1) and is_main_process():
# 切换到评估模式, 确保保存时模型处于稳定状态
model.eval()
# 根据是否使用 MoE 构建模型权重文件名后缀
moe_suffix = '_moe' if lm_config.use_moe else ''
# 构建权重保存路径
ckp = f'{args.save_dir}/{args.save_weight}_{lm_config.hidden_size}{moe_suffix}.pth'
# 如果模型被 DDP 包装, 获取原始模型
raw_model = model.module if isinstance(model, DistributedDataParallel) else model
# 如果模型被 torch.compile 编译, 获取原始模型
raw_model = getattr(raw_model, '_orig_mod', raw_model)
# 获取模型状态字典
state_dict = raw_model.state_dict()
# 将权重转换为半精度 (FP16) 并保存到 CPU, 然后保存到磁盘
torch.save({k: v.half().cpu() for k, v in state_dict.items()}, ckp)
lm_checkpoint(lm_config, weight=args.save_weight, model=model, optimizer=optimizer,
epoch=epoch, step=step, wandb=wandb, save_dir='../checkpoints', scaler=scaler)
# 保存完整的检查点 (包含优化器状态等), 用于恢复训练
lm_checkpoint(lm_config, weight=args.save_weight, model=model, optimizer=optimizer, epoch=epoch, step=step, wandb=wandb, save_dir='../checkpoints', scaler=scaler)
# 切换回训练模式
model.train()
# 删除状态字典以释放内存
del state_dict
# 删除张量以释放 GPU 内存
del input_ids, labels, res, loss
@ -87,76 +143,112 @@ if __name__ == "__main__":
parser.add_argument("--save_interval", type=int, default=1000, help="模型保存间隔")
parser.add_argument('--hidden_size', default=512, type=int, help="隐藏层维度")
parser.add_argument('--num_hidden_layers', default=8, type=int, help="隐藏层数量")
parser.add_argument('--max_seq_len', default=340, type=int, help="训练的最大截断长度中文1token≈1.5~1.7字符)")
parser.add_argument('--use_moe', default=0, type=int, choices=[0, 1], help="是否使用MoE架构0=否1=是)")
parser.add_argument('--max_seq_len', default=340, type=int, help="训练的最大截断长度 (中文 1 token ≈ 1.5~1.7 字符)")
parser.add_argument('--use_moe', default=0, type=int, choices=[0, 1], help="是否使用MoE架构(0=否, 1=是)")
parser.add_argument("--data_path", type=str, default="../dataset/sft_mini_512.jsonl", help="训练数据路径")
parser.add_argument('--from_weight', default='pretrain', type=str, help="基于哪个权重训练为none则不基于任何权重训练")
parser.add_argument('--from_resume', default=0, type=int, choices=[0, 1], help="是否自动检测&续训0=否1=是)")
parser.add_argument('--from_weight', default='pretrain', type=str, help="基于哪个权重训练, 为 none 则不基于任何权重训练")
parser.add_argument('--from_resume', default=0, type=int, choices=[0, 1], help="是否自动检测&续训(0=否, 1=是)")
parser.add_argument("--use_wandb", action="store_true", help="是否使用wandb")
parser.add_argument("--wandb_project", type=str, default="MiniMind-Full-SFT", help="wandb项目名")
parser.add_argument("--use_compile", default=0, type=int, choices=[0, 1], help="是否使用torch.compile加速0=否1=是)")
parser.add_argument("--use_compile", default=0, type=int, choices=[0, 1], help="是否使用torch.compile加速(0=否, 1=是)")
args = parser.parse_args()
# ========== 1. 初始化环境和随机种子 ==========
# ========== 1.初始化环境和随机种子 ==========
# 初始化分布式训练模式, 如果是分布式训练则返回 local_rank, 否则返回 0
local_rank = init_distributed_mode()
# 如果分布式环境已初始化, 根据 local_rank 设置当前进程使用的 GPU 设备
if dist.is_initialized(): args.device = f"cuda:{local_rank}"
# 设置随机种子, 确保实验可复现, 不同 rank 使用不同种子避免数据同步问题
setup_seed(42 + (dist.get_rank() if dist.is_initialized() else 0))
# ========== 2. 配置目录、模型参数、检查ckp ==========
# ========== 2.配置目录、模型参数、检查ckp ==========
# 创建模型保存目录, 如果目录已存在则不报错
os.makedirs(args.save_dir, exist_ok=True)
# 创建模型配置对象, 包含隐藏层维度、层数、是否使用 MoE 等参数
lm_config = MiniMindConfig(hidden_size=args.hidden_size, num_hidden_layers=args.num_hidden_layers, use_moe=bool(args.use_moe))
# 如果启用自动恢复训练, 尝试从检查点加载数据, 否则返回 None
ckp_data = lm_checkpoint(lm_config, weight=args.save_weight, save_dir='../checkpoints') if args.from_resume==1 else None
# ========== 3. 设置混合精度 ==========
# ========== 3.设置混合精度 ==========
# 根据设备类型确定是 cuda 还是 cpu
device_type = "cuda" if "cuda" in args.device else "cpu"
# 根据配置选择混合精度的数据类型: bfloat16 或 float16
dtype = torch.bfloat16 if args.dtype == "bfloat16" else torch.float16
autocast_ctx = nullcontext() if device_type == "cpu" else torch.cuda.amp.autocast(dtype=dtype)
# ========== 4. 配wandb ==========
# https://docs.pytorch.ac.cn/docs/stable/amp
# 新版废弃: autocast_ctx = nullcontext() if device_type == "cpu" else torch.cuda.amp.autocast(dtype=dtype)
autocast_ctx = nullcontext() if device_type == "cpu" else torch.amp.autocast(device_type=device_type, dtype=dtype)
# ========== 4.配置 wandb ==========
wandb = None
# 仅在主进程中初始化 Weights & Biases 日志记录
if args.use_wandb and is_main_process():
import swanlab as wandb
# 如果有检查点数据, 获取之前的 wandb 运行 ID 用于恢复日志记录
wandb_id = ckp_data.get('wandb_id') if ckp_data else None
# 如果有 wandb_id, 设置为必须恢复模式; 否则创建新运行
resume = 'must' if wandb_id else None
# 构建运行名称, 包含训练配置参数
wandb_run_name = f"MiniMind-Full-SFT-Epoch-{args.epochs}-BatchSize-{args.batch_size}-LearningRate-{args.learning_rate}"
# 初始化 wandb 运行
wandb.init(project=args.wandb_project, name=wandb_run_name, id=wandb_id, resume=resume)
# ========== 5. 定义模型、数据、优化器 ==========
# ========== 5.定义模型、数据、优化器 ==========
# 初始化模型和分词器, 可以选择从预训练权重加载
model, tokenizer = init_model(lm_config, args.from_weight, device=args.device)
# 如果启用 torch.compile, 编译模型以加速训练
if args.use_compile == 1:
model = torch.compile(model)
Logger('torch.compile enabled')
# 创建 SFT 数据集, 用于监督微调
train_ds = SFTDataset(args.data_path, tokenizer, max_length=args.max_seq_len)
# 如果启用分布式训练, 创建 DistributedSampler; 否则使用 None (后续会用普通索引)
train_sampler = DistributedSampler(train_ds) if dist.is_initialized() else None
scaler = torch.cuda.amp.GradScaler(enabled=(args.dtype == 'float16'))
# 创建梯度缩放器, 仅在 float16 模式下启用 (用于混合精度训练)
# https://docs.pytorch.ac.cn/docs/stable/amp.html#gradient-scaling
# 新版废弃: scaler = torch.cuda.amp.GradScaler(enabled=(args.dtype == 'float16'))
scaler = torch.amp.GradScaler("cuda", enabled=(args.dtype == 'float16'))
# 创建 AdamW 优化器
optimizer = optim.AdamW(model.parameters(), lr=args.learning_rate)
# ========== 6. 从ckp恢复状态 ==========
# ========== 6.从 ckp 恢复状态 ==========
# 初始化起始轮次和起始步数
start_epoch, start_step = 0, 0
# 如果有检查点数据, 恢复模型、优化器、scaler 的状态, 以及训练进度
if ckp_data:
model.load_state_dict(ckp_data['model'])
optimizer.load_state_dict(ckp_data['optimizer'])
scaler.load_state_dict(ckp_data['scaler'])
start_epoch = ckp_data['epoch']
start_step = ckp_data.get('step', 0)
# ========== 7. DDP包模型 ==========
# ========== 7.DDP 包装模型 ==========
# 如果启用分布式训练, 使用 DistributedDataParallel 包装模型
if dist.is_initialized():
# 设置 DDP 忽略同步的参数: RoPE 的位置编码不需要梯度同步
model._ddp_params_and_buffers_to_ignore = {"freqs_cos", "freqs_sin"}
# 使用 DDP 包装模型, 指定当前进程使用的 GPU 设备
model = DistributedDataParallel(model, device_ids=[local_rank])
# ========== 8. 开始训练 ==========
# ========== 8.开始训练 ==========
# 从 start_epoch 开始训练, 直到 args.epochs
for epoch in range(start_epoch, args.epochs):
# 设置 DistributedSampler 的 epoch, 确保每个 epoch 的数据打乱方式不同
train_sampler and train_sampler.set_epoch(epoch)
# 为每个 epoch 设置不同的随机种子, 确保数据打乱顺序不同
setup_seed(42 + epoch); indices = torch.randperm(len(train_ds)).tolist()
# 如果是恢复训练的第一个 epoch, 需要跳过已经训练过的 step; 其他 epoch 从 0 开始
skip = start_step if (epoch == start_epoch and start_step > 0) else 0
# 创建 SkipBatchSampler, 用于跳过指定数量的批次
batch_sampler = SkipBatchSampler(train_sampler or indices, args.batch_size, skip)
# 创建 DataLoader, 使用自定义的 batch_sampler
loader = DataLoader(train_ds, batch_sampler=batch_sampler, num_workers=args.num_workers, pin_memory=True)
if skip > 0:
Logger(f'Epoch [{epoch + 1}/{args.epochs}]: 跳过前{start_step}个step从step {start_step + 1}开始')
# 如果需要跳过批次, 打印提示信息并调用 train_epoch 时传入正确的总步数和起始步数
if skip > 0:
Logger(f'Epoch [{epoch + 1}/{args.epochs}]: 跳过前{start_step}个step, 从step {start_step + 1}开始')
train_epoch(epoch, loader, len(loader) + skip, start_step, wandb)
else:
train_epoch(epoch, loader, len(loader), 0, wandb)
# ========== 9. 清理分布进程 ==========
# 训练结束后, 如果启用了分布式训练, 销毁进程组
if dist.is_initialized(): dist.destroy_process_group()

View File

@ -1,13 +1,14 @@
import os
import sys
# 设置包名
__package__ = "trainer"
# 将项目根目录添加到系统路径, 确保能够导入同级目录的模块
sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), '..')))
import argparse
import re
import gc
import warnings
import torch
import torch.distributed as dist
from transformers import AutoTokenizer
@ -21,17 +22,43 @@ from model.model_minimind import MiniMindConfig, MiniMindForCausalLM
from dataset.lm_dataset import RLAIFDataset
from trainer.trainer_utils import Logger, is_main_process, lm_checkpoint, init_distributed_mode, setup_seed, SkipBatchSampler, init_model
# 忽略警告信息
import warnings
warnings.filterwarnings('ignore')
def calculate_rewards(prompts, responses, reward_model, reward_tokenizer):
"""整合所有奖励函数计算总奖励"""
"""
整合所有奖励函数计算总奖励
根据配置使用 reasoning 模型奖励函数和/ reward model 评分来计算总奖励
支持推理模型的格式检查和普通模型的 reward model 评分
Args:
- prompts: 输入 prompt 列表, list[str], 长度为 batch_size
- responses: 模型生成的回复列表, list[str], 长度为 batch_size * num_generations
- reward_model: reward model 实例, 用于计算语义相似度分数
- reward_tokenizer: reward model 的分词器
Returns:
- torch.Tensor: 每个生成样本的奖励分数, shape [B*num_gen]
"""
def reasoning_model_reward(rewards):
"""
如果使用 reasoning 模型, 应用 reasoning 奖励函数
if args.reasoning == 1:
rewards = reasoning_model_reward(rewards)
"""
# 定义推理模型的格式匹配模式: <think> 块后跟 <answer> 块
pattern = r"^<think>\n.*?\n</think>\n<answer>\n.*?\n</answer>$"
# 备选模式: 允许 <think> 和 <answer> 之间有一个空行
pattern2 = r"^<think>\n.*?\n</think>\n\n<answer>\n.*?\n</answer>$"
# 检查每个回复是否匹配格式
matches_pattern = [re.match(pattern, response, re.S) for response in responses]
matches_pattern2 = [re.match(pattern2, response, re.S) for response in responses]
# 计算格式奖励: 格式正确得 0.5 分, 否则 0 分
format_rewards = []
for match_pattern, match_pattern2 in zip(matches_pattern, matches_pattern2):
if match_pattern or match_pattern2:
@ -41,6 +68,7 @@ def calculate_rewards(prompts, responses, reward_model, reward_tokenizer):
rewards += torch.tensor(format_rewards, device=args.device)
def mark_num(text):
# 统计标签出现次数, 每个正确出现的标签奖励 0.25 分
reward = 0
if text.count("<think>") == 1: reward += 0.25
if text.count("</think>") == 1: reward += 0.25
@ -48,33 +76,42 @@ def calculate_rewards(prompts, responses, reward_model, reward_tokenizer):
if text.count("</answer>") == 1: reward += 0.25
return reward
# 计算标签数量奖励
mark_rewards = [mark_num(response) for response in responses]
rewards += torch.tensor(mark_rewards, device=args.device)
return rewards
# 初始化奖励张量为 0
rewards = torch.zeros(len(responses), device=args.device)
# 如果使用 reasoning 模型, 应用 reasoning 奖励函数
if args.reasoning == 1:
rewards = reasoning_model_reward(rewards)
# 使用 reward model 计算语义奖励
with torch.no_grad():
reward_model_scores = []
batch_size = len(prompts)
scale = 3.0
# 遍历每个 prompt 和每个生成的回复
for i in range(batch_size):
for j in range(args.num_generations):
response_idx = i * args.num_generations + j
response = responses[response_idx]
prompt = prompts[i]
# 解析 prompt 中的对话历史
pattern = r"<\|im_start\|>(system|user|assistant)\s+(.*?)<\|im_end\|>"
matches = re.findall(pattern, prompt, re.DOTALL)
messages = [{"role": role, "content": content.strip()} for role, content in matches]
# 构建完整的对话上下文并计算 reward model 分数
tmp_chat = messages + [{"role": "assistant", "content": response}]
score = reward_model.get_score(reward_tokenizer, tmp_chat)
# 将分数裁剪到 [-scale, scale] 范围
score = max(min(score, scale), -scale)
# 对于 reasoning 模型, 额外计算 answer 部分的分数
if args.reasoning == 1:
answer_match = re.search(r'<answer>(.*?)</answer>', response, re.DOTALL)
if answer_match:
@ -82,10 +119,12 @@ def calculate_rewards(prompts, responses, reward_model, reward_tokenizer):
tmp_chat = messages + [{"role": "assistant", "content": answer_content}]
answer_score = reward_model.get_score(reward_tokenizer, tmp_chat)
answer_score = max(min(answer_score, scale), -scale)
# 综合分数: 40% 完整回复 + 60% answer 部分
score = score * 0.4 + answer_score * 0.6
reward_model_scores.append(score)
# 将 reward model 分数添加到总奖励
reward_model_scores = torch.tensor(reward_model_scores, device=args.device)
rewards += reward_model_scores
@ -93,24 +132,58 @@ def calculate_rewards(prompts, responses, reward_model, reward_tokenizer):
def grpo_train_epoch(epoch, loader, iters, ref_model, reward_model, reward_tokenizer, start_step=0, wandb=None):
"""
GRPO (Group Relative Policy Optimization) epoch 训练函数
实现 GRPO 算法的核心训练循环, 包括:
- 生成多组回复样本
- 计算 reward advantage
- 计算 policy loss KL divergence
- 更新模型参数
Args:
- epoch: 当前训练轮次
- loader: 数据加载器
- iters: 总迭代次数
- ref_model: reference model (冻结参数), 用于计算 KL divergence
- reward_model: reward model, 用于计算奖励分数
- reward_tokenizer: reward model 的分词器
- start_step: 从第几步开始 (用于断点续训)
- wandb: 日志记录对象
"""
for step, batch in enumerate(loader, start=start_step + 1):
# 获取 prompt 列表并编码
prompts = batch['prompt'] # list[str], length B
prompt_inputs = tokenizer(prompts, return_tensors="pt", padding=True, return_token_type_ids=False,
padding_side="left", add_special_tokens=False).to(args.device) # input_ids: [B, P], attention_mask: [B, P]
prompt_inputs = tokenizer(
prompts,
return_tensors="pt",
padding=True,
return_token_type_ids=False,
padding_side="left",
add_special_tokens=False).to(args.device) # input_ids: [B, P], attention_mask: [B, P]
# 截断过长的 prompt
if args.max_seq_len:
prompt_inputs["input_ids"] = prompt_inputs["input_ids"][:, -args.max_seq_len:]
prompt_inputs["attention_mask"] = prompt_inputs["attention_mask"][:, -args.max_seq_len:]
# 使用当前 model 生成多个回复
with torch.no_grad():
# DDP 模型需要使用 .module 访问 generate 方法
model_for_gen = model.module if isinstance(model, DistributedDataParallel) else model
outputs = model_for_gen.generate(
**prompt_inputs, max_new_tokens=args.max_gen_len, do_sample=True, temperature=0.8,
num_return_sequences=args.num_generations, pad_token_id=tokenizer.pad_token_id) # [B*num_gen, P+R]
**prompt_inputs,
max_new_tokens=args.max_gen_len,
do_sample=True, temperature=0.8,
num_return_sequences=args.num_generations,
pad_token_id=tokenizer.pad_token_id) # [B*num_gen, P+R]
# 提取生成的回复部分 (去掉 prompt)
completion_ids = outputs[:, prompt_inputs["input_ids"].size(1):] # [B*num_gen, R]
def get_per_token_logps(mdl, input_ids, n_keep):
"""
获取每个 token log probability
"""
input_ids = input_ids.detach().clone() if input_ids.is_inference() else input_ids
logits = mdl(input_ids, logits_to_keep=n_keep + 1).logits[:, :-1, :]
per_token_logps = []
@ -119,35 +192,44 @@ def grpo_train_epoch(epoch, loader, iters, ref_model, reward_model, reward_token
per_token_logps.append(torch.gather(logits_row.log_softmax(dim=-1), 1, ids_row.unsqueeze(1)).squeeze(1))
return torch.stack(per_token_logps)
# 计算 policy model 的 per-token log probabilities
with autocast_ctx:
per_token_logps = get_per_token_logps(model, outputs, completion_ids.size(1)) # [B*num_gen, R]
# 如果使用 MoE, 计算辅助损失
res = model(outputs) if lm_config.use_moe else None
aux_loss = res.aux_loss if res is not None else torch.tensor(0.0, device=args.device)
# 计算 reference model 的 per-token log probabilities (无梯度)
with torch.no_grad():
ref_per_token_logps = get_per_token_logps(ref_model, outputs, completion_ids.size(1)) # [B*num_gen, R]
# 解码生成的回复并计算奖励
completions = tokenizer.batch_decode(completion_ids, skip_special_tokens=True)
rewards = calculate_rewards(prompts, completions, reward_model, reward_tokenizer).to(args.device) # [B*num_gen]
grouped_rewards = rewards.view(-1, args.num_generations) # [B, num_gen]
mean_r = grouped_rewards.mean(dim=1).repeat_interleave(args.num_generations) # [B*num_gen]
std_r = grouped_rewards.std(dim=1).repeat_interleave(args.num_generations) # [B*num_gen]
# 按 group 分组计算 advantage (GRPO 核心)
grouped_rewards = rewards.view(-1, args.num_generations) # [B, num_gen]
mean_r = grouped_rewards.mean(dim=1).repeat_interleave(args.num_generations) # [B*num_gen]
std_r = grouped_rewards.std(dim=1).repeat_interleave(args.num_generations) # [B*num_gen]
advantages = torch.clamp((rewards - mean_r) / (std_r + 1e-4), -10, 10)
advantages = (advantages - advantages.mean()) / (advantages.std() + 1e-8) # [B*num_gen]
advantages = (advantages - advantages.mean()) / (advantages.std() + 1e-8) # [B*num_gen]
is_eos = completion_ids == tokenizer.eos_token_id # [B*num_gen, R]
# 构建 completion mask (处理 EOS 提前结束的情况)
is_eos = completion_ids == tokenizer.eos_token_id # [B*num_gen, R]
eos_idx = torch.full((is_eos.size(0),), is_eos.size(1), dtype=torch.long, device=args.device)
eos_idx[is_eos.any(dim=1)] = is_eos.int().argmax(dim=1)[is_eos.any(dim=1)]
completion_mask = (torch.arange(is_eos.size(1), device=args.device).expand(is_eos.size(0), -1) <= eos_idx.unsqueeze(1)).int() # [B*num_gen, R]
# 计算 KL divergence
kl_div = ref_per_token_logps - per_token_logps
per_token_kl = torch.exp(kl_div) - kl_div - 1 # [B*num_gen, R]
per_token_kl = torch.exp(kl_div) - kl_div - 1 # [B*num_gen, R]
# 计算 GRPO loss: advantage-weighted loss - beta * KL penalty
per_token_loss = -(torch.exp(per_token_logps - per_token_logps.detach()) * advantages.unsqueeze(1) - args.beta * per_token_kl) # [B*num_gen, R]
policy_loss = ((per_token_loss * completion_mask).sum(dim=1) / completion_mask.sum(dim=1)).mean()
loss = (policy_loss + aux_loss) / args.accumulation_steps # scalar
loss = (policy_loss + aux_loss) / args.accumulation_steps # scalar
loss.backward()
# 梯度累积更新
if (step + 1) % args.accumulation_steps == 0:
if args.grad_clip > 0:
torch.nn.utils.clip_grad_norm_(model.parameters(), args.grad_clip)
@ -155,6 +237,7 @@ def grpo_train_epoch(epoch, loader, iters, ref_model, reward_model, reward_token
scheduler.step()
optimizer.zero_grad()
# 日志记录
if step % args.log_interval == 0 or step == iters:
policy_loss_val = loss.item() * args.accumulation_steps
current_aux_loss = aux_loss.item()
@ -176,19 +259,24 @@ def grpo_train_epoch(epoch, loader, iters, ref_model, reward_model, reward_token
"learning_rate": current_lr
})
# 模型保存检查点
if (step % args.save_interval == 0 or step == iters - 1) and is_main_process():
model.eval()
moe_suffix = '_moe' if lm_config.use_moe else ''
ckp = f'{args.save_dir}/{args.save_weight}_{lm_config.hidden_size}{moe_suffix}.pth'
# 解包 DDP 或 torch.compile 包装的模型
raw_model = model.module if isinstance(model, DistributedDataParallel) else model
raw_model = getattr(raw_model, '_orig_mod', raw_model)
state_dict = raw_model.state_dict()
# 保存半精度权重到 CPU
torch.save({k: v.half().cpu() for k, v in state_dict.items()}, ckp)
# 保存完整的检查点 (用于断点续训)
lm_checkpoint(lm_config, weight=args.save_weight, model=model, optimizer=optimizer,
epoch=epoch, step=step, wandb=wandb, save_dir='../checkpoints', scheduler=scheduler)
model.train()
del state_dict
# 清理中间变量释放内存
del prompt_inputs, outputs, completion_ids, per_token_logps, ref_per_token_logps
del completions, rewards, grouped_rewards, mean_r, std_r, advantages, completion_mask
@ -209,37 +297,47 @@ if __name__ == "__main__":
parser.add_argument("--save_interval", type=int, default=10, help="模型保存间隔")
parser.add_argument('--hidden_size', default=512, type=int, help="隐藏层维度")
parser.add_argument('--num_hidden_layers', default=8, type=int, help="隐藏层数量")
parser.add_argument('--use_moe', default=0, type=int, choices=[0, 1], help="是否使用MoE架构0=否1=是)")
parser.add_argument('--use_moe', default=0, type=int, choices=[0, 1], help="是否使用MoE架构 (0=否, 1=是)")
parser.add_argument('--max_seq_len', default=66, type=int, help="Prompt最大长度")
parser.add_argument("--max_gen_len", type=int, default=1536, help="生成的最大长度")
parser.add_argument("--data_path", type=str, default="../dataset/rlaif-mini.jsonl", help="RLAIF数据路径")
parser.add_argument("--num_generations", type=int, default=8, help="每个prompt生成的样本数")
parser.add_argument("--beta", type=float, default=0.02, help="KL惩罚系数")
parser.add_argument("--reasoning", type=int, default=1, choices=[0, 1], help='推理模型类型0=普通模型1=推理模型)')
parser.add_argument("--reasoning", type=int, default=1, choices=[0, 1], help='推理模型类型 (0=普通模型, 1=推理模型)')
parser.add_argument("--reward_model_path", type=str, default="../../internlm2-1_8b-reward", help="Reward模型路径")
parser.add_argument('--from_resume', default=0, type=int, choices=[0, 1], help="是否自动检测&续训0=否1=是)")
parser.add_argument('--from_resume', default=0, type=int, choices=[0, 1], help="是否自动检测&续训 (0=否, 1=是)")
parser.add_argument("--use_wandb", action="store_true", help="是否使用wandb")
parser.add_argument("--wandb_project", type=str, default="MiniMind-GRPO", help="wandb项目名")
parser.add_argument("--use_compile", default=0, type=int, choices=[0, 1], help="是否使用torch.compile加速0=否1=是)")
parser.add_argument("--use_compile", default=0, type=int, choices=[0, 1], help="是否使用torch.compile加速 (0=否, 1=是)")
args = parser.parse_args()
# ========== 1. 初始化环境和随机种子 ==========
# ========== 1.初始化环境和随机种子 ==========
local_rank = init_distributed_mode()
# 如果启用了分布式训练, 根据 local_rank 设置设备
if dist.is_initialized(): args.device = f"cuda:{local_rank}"
# 设置随机种子, 分布式训练时根据 rank 偏移确保不同进程使用不同种子
setup_seed(42 + (dist.get_rank() if dist.is_initialized() else 0))
# ========== 2. 配置目录、模型参数、检查ckp ==========
# ========== 2.配置目录、模型参数、检查 ckp ==========
os.makedirs(args.save_dir, exist_ok=True)
lm_config = MiniMindConfig(hidden_size=args.hidden_size, num_hidden_layers=args.num_hidden_layers,
max_seq_len=args.max_seq_len + args.max_gen_len, use_moe=bool(args.use_moe))
# 创建模型配置, max_seq_len 包含 prompt 和生成的最大长度
lm_config = MiniMindConfig(
hidden_size=args.hidden_size,
num_hidden_layers=args.num_hidden_layers,
max_seq_len=args.max_seq_len + args.max_gen_len,
use_moe=bool(args.use_moe)
)
# 如果启用断点续训, 尝试加载检查点
ckp_data = lm_checkpoint(lm_config, weight=args.save_weight, save_dir='../checkpoints') if args.from_resume==1 else None
# ========== 3. 设置混合精度 ==========
# ========== 3.设置混合精度 ==========
device_type = "cuda" if "cuda" in args.device else "cpu"
dtype = torch.bfloat16 if args.dtype == "bfloat16" else torch.float16
autocast_ctx = nullcontext() if device_type == "cpu" else torch.cuda.amp.autocast(dtype=dtype)
# ========== 4. 配wandb ==========
# https://docs.pytorch.ac.cn/docs/stable/amp
# 新版废弃: autocast_ctx = nullcontext() if device_type == "cpu" else torch.cuda.amp.autocast(dtype=dtype)
autocast_ctx = nullcontext() if device_type == "cpu" else torch.amp.autocast(device_type=device_type, dtype=dtype)
# ========== 4.配置 wandb ==========
wandb = None
if args.use_wandb and is_main_process():
import swanlab as wandb
@ -248,49 +346,56 @@ if __name__ == "__main__":
wandb_run_name = f"MiniMind-GRPO-Epoch-{args.epochs}-BS-{args.batch_size}-LR-{args.learning_rate}"
wandb.init(project=args.wandb_project, name=wandb_run_name, id=wandb_id, resume=resume)
# ========== 5. 初始化模型和数据 ==========
# ========== 5.初始化模型和数据 ==========
# 根据是否为 reasoning 模型选择基础权重
base_weight = "reason" if args.reasoning == 1 else "full_sft"
# Policy模型
# Policy 模型 (将被训练的模型)
model, tokenizer = init_model(lm_config, base_weight, device=args.device)
# 如果启用 torch.compile 加速
if args.use_compile == 1:
model = torch.compile(model)
Logger('torch.compile enabled')
# Reference模型
# Reference 模型 (冻结参数, 用于计算 KL divergence)
ref_model, _ = init_model(lm_config, base_weight, device=args.device)
ref_model = ref_model.eval().requires_grad_(False)
# Reward模型
reward_model = AutoModel.from_pretrained(
args.reward_model_path, torch_dtype=torch.float16, trust_remote_code=True
)
# Reward 模型 (用于计算奖励分数, 冻结参数)
reward_model = AutoModel.from_pretrained(args.reward_model_path, torch_dtype=torch.float16, trust_remote_code=True)
reward_model = reward_model.to(args.device).eval().requires_grad_(False)
# Reward 模型分词器
reward_tokenizer = AutoTokenizer.from_pretrained(args.reward_model_path, trust_remote_code=True)
# 数据和优化器
# 加载数据创建优化器
train_ds = RLAIFDataset(args.data_path, tokenizer, max_length=lm_config.max_seq_len)
train_sampler = DistributedSampler(train_ds) if dist.is_initialized() else None
optimizer = optim.AdamW(model.parameters(), lr=args.learning_rate)
# 计算训练迭代次数和学习率调度器
loader_for_count = DataLoader(train_ds, batch_size=args.batch_size, sampler=train_sampler)
iters = len(loader_for_count)
total_optimizer_steps = (iters // args.accumulation_steps) * args.epochs
scheduler = CosineAnnealingLR(optimizer, T_max=total_optimizer_steps, eta_min=args.learning_rate / 10)
# ========== 6. 从ckp恢复状态 ==========
# ========== 6. ckp 恢复状态 ==========
start_epoch, start_step = 0, 0
if ckp_data:
# 加载模型权重、优化器状态和学习率调度器状态
model.load_state_dict(ckp_data['model'])
optimizer.load_state_dict(ckp_data['optimizer'])
scheduler.load_state_dict(ckp_data['scheduler'])
start_epoch = ckp_data['epoch']
start_step = ckp_data.get('step', 0)
# ========== 7. DDP包模型 ==========
# ========== 7.DDP 模型 ==========
if dist.is_initialized():
# 忽略位置编码参数不参与 DDP 同步
model._ddp_params_and_buffers_to_ignore = {"freqs_cos", "freqs_sin"}
model = DistributedDataParallel(model, device_ids=[local_rank])
# ========== 8. 开始训练 ==========
# ========== 8.开始训练 ==========
for epoch in range(start_epoch, args.epochs):
# 分布式采样器设置 epoch 以确保不同进程数据不同
train_sampler and train_sampler.set_epoch(epoch)
# 每轮使用不同的随机种子打乱数据
setup_seed(42 + epoch); indices = torch.randperm(len(train_ds)).tolist()
# 如果是续训的第一个 epoch, 跳过已训练的 step
skip = start_step if (epoch == start_epoch and start_step > 0) else 0
batch_sampler = SkipBatchSampler(train_sampler or indices, args.batch_size, skip)
loader = DataLoader(train_ds, batch_sampler=batch_sampler, num_workers=args.num_workers, pin_memory=True)
@ -300,5 +405,5 @@ if __name__ == "__main__":
else:
grpo_train_epoch(epoch, loader, len(loader), ref_model, reward_model, reward_tokenizer, 0, wandb)
# ========== 9. 清理分布进程 ==========
# ========== 9.清理分布进程 ==========
if dist.is_initialized(): dist.destroy_process_group()

View File

@ -1,12 +1,13 @@
import os
import sys
# 设置包名
__package__ = "trainer"
# 将项目根目录添加到系统路径, 确保能够导入同级目录的模块
sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), '..')))
import argparse
import time
import warnings
import torch
import torch.distributed as dist
from contextlib import nullcontext
@ -18,50 +19,95 @@ from dataset.lm_dataset import SFTDataset
from model.model_lora import save_lora, apply_lora
from trainer.trainer_utils import get_lr, Logger, is_main_process, lm_checkpoint, init_distributed_mode, setup_seed, init_model, SkipBatchSampler
# 忽略所有警告信息, 保持输出整洁
import warnings
warnings.filterwarnings('ignore')
def train_epoch(epoch, loader, iters, lora_params, start_step=0, wandb=None):
"""
训练单个 epoch 的函数
执行 LoRA 微调的一个完整训练周期, 包括前向传播反向传播梯度更新和日志记录
Args:
- epoch: 当前训练轮次
- loader: 数据加载器, 提供训练数据批次
- iters: 当前 epoch 的总迭代次数
- lora_params: LoRA 可训练参数列表, 用于梯度裁剪
- start_step: 起始步数 (用于断点续训时跳过已训练部分)
- wandb: Weights & Biases 日志对象, 可选
"""
# 记录训练开始时间, 用于计算 ETA (预计完成时间)
start_time = time.time()
# 遍历数据加载器中的每个批次, 从 start_step + 1 开始计数
for step, (input_ids, labels) in enumerate(loader, start=start_step + 1):
# 将输入数据移动到指定设备 (GPU 或 CPU)
input_ids = input_ids.to(args.device)
labels = labels.to(args.device)
# 使用余弦退火策略计算当前步的学习率
lr = get_lr(epoch * iters + step, args.epochs * iters, args.learning_rate)
# 更新优化器中的学习率
for param_group in optimizer.param_groups:
param_group['lr'] = lr
# 使用自动混合精度上下文进行前向传播 (仅在 GPU 时启用)
with autocast_ctx:
# 模型前向传播, 计算损失 (主损失 + 辅助损失, 如 MoE 负载均衡损失)
res = model(input_ids, labels=labels)
loss = res.loss + res.aux_loss
# 梯度累积: 将损失除以累积步数, 模拟大批量训练效果
loss = loss / args.accumulation_steps
# 使用梯度缩放器进行反向传播 (防止 FP16 梯度下溢)
scaler.scale(loss).backward()
# 每 accumulation_steps 步执行一次参数更新
if (step + 1) % args.accumulation_steps == 0:
# 反缩放梯度, 准备进行梯度裁剪
scaler.unscale_(optimizer)
# 对 LoRA 参数进行梯度裁剪, 防止梯度爆炸
torch.nn.utils.clip_grad_norm_(lora_params, args.grad_clip)
# 执行优化器步骤, 更新参数
scaler.step(optimizer)
# 更新梯度缩放器的缩放因子
scaler.update()
# 清空梯度, 释放内存
optimizer.zero_grad(set_to_none=True)
# 每隔 log_interval 步或最后一步打印训练日志
if step % args.log_interval == 0 or step == iters - 1:
# 计算已花费时间
spend_time = time.time() - start_time
# 恢复原始损失值 (乘以累积步数)
current_loss = loss.item() * args.accumulation_steps
# 获取辅助损失值 (MoE 负载均衡损失)
current_aux_loss = res.aux_loss.item() if res.aux_loss is not None else 0.0
# 计算主损失 (总损失减去辅助损失)
current_logits_loss = current_loss - current_aux_loss
# 获取当前学习率
current_lr = optimizer.param_groups[-1]['lr']
# 计算预计剩余时间 (ETA)
eta_min = spend_time / (step + 1) * iters // 60 - spend_time // 60
# 在主进程打印训练进度日志
Logger(f'Epoch:[{epoch + 1}/{args.epochs}]({step}/{iters}), loss: {current_loss:.4f}, logits_loss: {current_logits_loss:.4f}, aux_loss: {current_aux_loss:.4f}, lr: {current_lr:.8f}, epoch_time: {eta_min:.1f}min')
# 如果使用 wandb, 记录训练指标
if wandb: wandb.log({"loss": current_loss, "logits_loss": current_logits_loss, "aux_loss": current_aux_loss, "learning_rate": current_lr, "epoch_time": eta_min})
# 每隔 save_interval 步或最后一步保存模型检查点 (仅主进程执行)
if (step % args.save_interval == 0 or step == iters - 1) and is_main_process():
# 切换到评估模式, 准备保存模型
model.eval()
# 构建 LoRA 权重保存路径
lora_save_path = f'{args.save_dir}/{args.lora_name}_{lm_config.hidden_size}.pth'
# LoRA只保存LoRA权重
save_lora(model, lora_save_path)
# 保存完整的训练检查点 (包含模型、优化器、学习率调度器状态)
lm_checkpoint(lm_config, weight=args.lora_name, model=model, optimizer=optimizer, scaler=scaler, epoch=epoch, step=step, wandb=wandb, save_dir='../checkpoints')
# 切换回训练模式
model.train()
# 显式删除张量, 释放 GPU 内存
del input_ids, labels, res, loss
@ -78,98 +124,146 @@ if __name__ == "__main__":
parser.add_argument("--accumulation_steps", type=int, default=1, help="梯度累积步数")
parser.add_argument("--grad_clip", type=float, default=1.0, help="梯度裁剪阈值")
parser.add_argument("--log_interval", type=int, default=10, help="日志打印间隔")
parser.add_argument("--save_interval", type=int, default=1000, help="模型保存间隔")
parser.add_argument("--save_interval", type=int, default=1, help="模型保存间隔")
parser.add_argument('--hidden_size', default=512, type=int, help="隐藏层维度")
parser.add_argument('--num_hidden_layers', default=8, type=int, help="隐藏层数量")
parser.add_argument('--max_seq_len', default=340, type=int, help="训练的最大截断长度中文1token≈1.5~1.7字符)")
parser.add_argument('--use_moe', default=0, type=int, choices=[0, 1], help="是否使用MoE架构0=否1=是)")
parser.add_argument('--max_seq_len', default=340, type=int, help="训练的最大截断长度 (中文 1 token ≈ 1.5~1.7 字符)")
parser.add_argument('--use_moe', default=0, type=int, choices=[0, 1], help="是否使用MoE架构 (0=否, 1=是)")
parser.add_argument("--data_path", type=str, default="../dataset/lora_identity.jsonl", help="LoRA训练数据路径")
parser.add_argument('--from_weight', default='full_sft', type=str, help="基于哪个权重训练,默认full_sft")
parser.add_argument('--from_resume', default=0, type=int, choices=[0, 1], help="是否自动检测&续训0=否1=是)")
parser.add_argument('--from_weight', default='full_sft', type=str, help="基于哪个权重训练, 默认 full_sft")
parser.add_argument('--from_resume', default=0, type=int, choices=[0, 1], help="是否自动检测&续训 (0=否, 1=是)")
parser.add_argument("--use_wandb", action="store_true", help="是否使用wandb")
parser.add_argument("--wandb_project", type=str, default="MiniMind-LoRA", help="wandb项目名")
parser.add_argument("--use_compile", default=0, type=int, choices=[0, 1], help="是否使用torch.compile加速0=否1=是)")
parser.add_argument("--use_compile", default=0, type=int, choices=[0, 1], help="是否使用torch.compile加速 (0=否, 1=是)")
args = parser.parse_args()
# ========== 1. 初始化环境和随机种子 ==========
# ========== 1.初始化环境和随机种子 ==========
# 初始化分布式训练模式, 返回本地 GPU 编号
local_rank = init_distributed_mode()
# 如果启用了分布式训练, 将设备设置为当前进程对应的 GPU
if dist.is_initialized(): args.device = f"cuda:{local_rank}"
setup_seed(42 + (dist.get_rank() if dist.is_initialized() else 0))
# 设置随机种子, 不同进程使用不同的种子以避免数据重复 (基础种子 42 + 进程 rank)
# ========== 2. 配置目录、模型参数、检查ckp ==========
# ========== 2.配置目录、模型参数、检查ckp ==========
# 确保模型保存目录存在
os.makedirs(args.save_dir, exist_ok=True)
# 创建 MiniMind 模型配置对象, 设置隐藏层维度、层数和是否使用 MoE
lm_config = MiniMindConfig(hidden_size=args.hidden_size, num_hidden_layers=args.num_hidden_layers, use_moe=bool(args.use_moe))
# 如果启用了断点续训, 尝试从检查点加载训练状态
ckp_data = lm_checkpoint(lm_config, weight=args.lora_name, save_dir='../checkpoints') if args.from_resume==1 else None
# ========== 3. 设置混合精度 ==========
# ========== 3.设置混合精度 ==========
device_type = "cuda" if "cuda" in args.device else "cpu"
# 根据参数选择数据类型: bfloat16 或 float16
dtype = torch.bfloat16 if args.dtype == "bfloat16" else torch.float16
autocast_ctx = nullcontext() if device_type == "cpu" else torch.cuda.amp.autocast(dtype=dtype)
# https://docs.pytorch.ac.cn/docs/stable/amp
# 新版废弃: autocast_ctx = nullcontext() if device_type == "cpu" else torch.cuda.amp.autocast(dtype=dtype)
autocast_ctx = nullcontext() if device_type == "cpu" else torch.amp.autocast(device_type=device_type, dtype=dtype)
# ========== 4. 配wandb ==========
# ========== 4.配置 wandb ==========
# 初始化 wandb 对象为 None, 如果不启用则保持为 None
wandb = None
# 仅在主进程且启用 wandb 时初始化
if args.use_wandb and is_main_process():
# 导入 swanlab 作为 wandb 的替代 (可能是国内用户适配)
import swanlab as wandb
# 从检查点恢复 wandb 运行 ID (如果存在)
wandb_id = ckp_data.get('wandb_id') if ckp_data else None
# 如果存在 wandb_id, 则必须恢复原有运行; 否则创建新运行
resume = 'must' if wandb_id else None
# 构建 wandb 运行名称, 包含关键训练参数
wandb_run_name = f"MiniMind-LoRA-{args.lora_name}-Epoch-{args.epochs}-BatchSize-{args.batch_size}-LR-{args.learning_rate}"
# 初始化 wandb 运行
wandb.init(project=args.wandb_project, name=wandb_run_name, id=wandb_id, resume=resume)
# ========== 5. 定义模型、应用LoRA、冻结非LoRA参数 ==========
# ========== 5.定义模型、应用LoRA、冻结非LoRA参数 ==========
# 初始化 MiniMind 模型和分词器, 从指定权重加载预训练参数
model, tokenizer = init_model(lm_config, args.from_weight, device=args.device)
# 如果启用 torch.compile, 编译模型以加速训练 (PyTorch 2.0+ 特性)
if args.use_compile == 1:
model = torch.compile(model)
Logger('torch.compile enabled')
# 应用 LoRA (Low-Rank Adaptation) 到模型, 添加可训练的低秩矩阵
apply_lora(model)
# 统计参数
# 统计模型参数
total_params = sum(p.numel() for p in model.parameters())
# 统计 LoRA 相关参数量 (参数名中包含 'lora' 的参数)
lora_params_count = sum(p.numel() for name, p in model.named_parameters() if 'lora' in name)
# 打印模型参数量统计信息
Logger(f"LLM 总参数量: {total_params / 1e6:.3f} M")
Logger(f"LoRA 参数量: {lora_params_count / 1e6:.3f} M")
Logger(f"LoRA 参数占比: {lora_params_count / total_params * 100:.2f}%")
# 冻结非LoRA参数收集LoRA参数
# 冻结非 LoRA 参数, 收集 LoRA 参数
lora_params = []
for name, param in model.named_parameters():
if 'lora' in name:
# 启用 LoRA 参数的梯度计算
param.requires_grad = True
# 将 LoRA 参数添加到可训练参数列表
lora_params.append(param)
else:
# 冻结非 LoRA 参数 (原始模型参数不更新)
param.requires_grad = False
# ========== 6. 定义数据和优化器 ==========
# ========== 6.定义数据和优化器 ==========
# 创建指令微调数据集 (SFTDataset), 加载 JSONL 格式的训练数据
train_ds = SFTDataset(args.data_path, tokenizer, max_length=args.max_seq_len)
# 如果启用分布式训练, 使用 DistributedSampler 确保数据不重复; 否则为 None
train_sampler = DistributedSampler(train_ds) if dist.is_initialized() else None
scaler = torch.cuda.amp.GradScaler(enabled=(args.dtype == 'float16'))
# 创建梯度缩放器, 仅在 float16 模式下启用 (用于混合精度训练, 防止梯度下溢)
# https://docs.pytorch.ac.cn/docs/stable/amp.html#gradient-scaling
# 新版废弃: scaler = torch.cuda.amp.GradScaler(enabled=(args.dtype == 'float16'))
scaler = torch.amp.GradScaler("cuda", enabled=(args.dtype == 'float16'))
# 创建 AdamW 优化器, 仅优化 LoRA 参数
optimizer = optim.AdamW(lora_params, lr=args.learning_rate)
# ========== 7. 从ckp恢复状态 ==========
# ========== 7.从检查点恢复状态 ==========
# 初始化起始 epoch 和 step 为 0 (从头训练)
start_epoch, start_step = 0, 0
# 如果存在检查点数据, 恢复训练状态
if ckp_data:
# 加载模型权重 (strict=False 允许部分加载, 适应 LoRA 场景)
model.load_state_dict(ckp_data['model'], strict=False)
# 恢复优化器状态 (包括动量等)
optimizer.load_state_dict(ckp_data['optimizer'])
# 恢复梯度缩放器状态
scaler.load_state_dict(ckp_data['scaler'])
# 恢复训练轮次
start_epoch = ckp_data['epoch']
# 恢复训练步数 (默认为 0)
start_step = ckp_data.get('step', 0)
# ========== 8. DDP包模型 ==========
# ========== 8.DDP包装模型 ==========
# 如果启用分布式训练, 使用 DDP 包装模型
if dist.is_initialized():
# 设置 DDP 忽略同步的参数 (位置编码相关的缓存, 不需要梯度同步)
model._ddp_params_and_buffers_to_ignore = {"freqs_cos", "freqs_sin"}
# 使用 DDP 包装模型, 指定当前进程使用的 GPU 设备
model = DistributedDataParallel(model, device_ids=[local_rank])
# ========== 9. 开始训练 ==========
# ========== 9.开始训练 ==========
# 遍历从 start_epoch 到总 epoch 数的每个轮次
for epoch in range(start_epoch, args.epochs):
# 如果存在分布式采样器, 设置当前 epoch (确保数据打乱顺序在不同 epoch 不同)
train_sampler and train_sampler.set_epoch(epoch)
# 设置当前 epoch 的随机种子, 并生成随机打乱的索引列表
setup_seed(42 + epoch); indices = torch.randperm(len(train_ds)).tolist()
# 如果是恢复训练的第一个 epoch 且需要跳过部分 step, 设置 skip; 否则为 0
skip = start_step if (epoch == start_epoch and start_step > 0) else 0
# 创建 SkipBatchSampler, 支持跳过指定数量的批次 (断点续训用)
batch_sampler = SkipBatchSampler(train_sampler or indices, args.batch_size, skip)
# 创建 DataLoader, 使用自定义的 batch_sampler
loader = DataLoader(train_ds, batch_sampler=batch_sampler, num_workers=args.num_workers, pin_memory=True)
# 如果需要跳过部分 step, 打印提示信息并调用 train_epoch
if skip > 0:
Logger(f'Epoch [{epoch + 1}/{args.epochs}]: 跳过前{start_step}个step从step {start_step + 1}开始')
train_epoch(epoch, loader, len(loader) + skip, lora_params, start_step, wandb)
else:
# 正常训练, 从 step 0 开始
train_epoch(epoch, loader, len(loader), lora_params, 0, wandb)
# ========== 10. 清理分布进程 ==========
# ========== 10.清理分布式进程 ==========
# 如果启用了分布式训练, 销毁进程组, 释放资源
if dist.is_initialized(): dist.destroy_process_group()

View File

@ -1,12 +1,13 @@
import os
import sys
# 设置包名
__package__ = "trainer"
# 将项目根目录添加到系统路径, 确保能够导入同级目录的模块
sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), '..')))
import argparse
import re
import warnings
import torch
import torch.distributed as dist
import torch.nn.functional as F
@ -22,35 +23,69 @@ from model.model_minimind import MiniMindConfig, MiniMindForCausalLM
from dataset.lm_dataset import RLAIFDataset
from trainer.trainer_utils import Logger, is_main_process, lm_checkpoint, init_distributed_mode, setup_seed, SkipBatchSampler, init_model
# 忽略警告信息, 避免输出干扰
import warnings
warnings.filterwarnings('ignore')
# 自定义的Critic模型继承自MiniMindLM
class CriticModel(MiniMindForCausalLM):
"""
Critic 模型, 继承自 MiniMindForCausalLM
用于估计状态价值, PPO 算法提供优势函数计算基础
"""
def __init__(self, params):
super().__init__(params)
# 替换lm_head为输出单一价值的线性层
# 添加价值头, 将隐藏层输出映射为单一标量价值
self.value_head = nn.Linear(params.hidden_size, 1)
def forward(self, input_ids=None, attention_mask=None, **kwargs):
# 使用基础模型获取隐藏状态
# 使用基础模型获取隐藏状态输出
outputs = self.model(input_ids=input_ids, attention_mask=attention_mask, **kwargs)
# 对最后一层隐藏状态进行 LayerNorm 归一化
hidden_states = self.model.norm(outputs[0])
# 使用value_head获取价值估计
# 通过 value_head 计算每个位置的价值估计, 并移除最后一个维度
values = self.value_head(hidden_states).squeeze(-1)
return values
def calculate_rewards(prompts, responses, reward_model, reward_tokenizer):
"""整合所有奖励函数计算总奖励"""
"""
整合所有奖励函数计算总奖励
Args:
- prompts: 输入提示列表, list[str]
- responses: 模型生成的响应列表, list[str]
- reward_model: 奖励模型, 用于评估回答质量
- reward_tokenizer: 奖励模型的分词器
Returns:
- torch.Tensor: 每个样本的总奖励值, shape [B]
"""
def reasoning_model_reward(rewards):
# 1. 格式奖励(仅针对训练推理模型时使用)
"""
推理模型的奖励计算
包含两部分奖励:
1. 格式奖励: 检查回答是否符合 <think>...</think><answer>...</answer> 格式
2. 标记奖励: 检查关键标记 (think, answer 标签) 的出现次数
Args:
- rewards: 当前奖励张量, shape [B]
Returns:
- torch.Tensor: 更新后的奖励张量
"""
# 定义正则表达式模式, 匹配标准的推理格式
# 模式1: <think> 和 <answer> 之间只有一个换行
pattern = r"^<think>\n.*?\n</think>\n<answer>\n.*?\n</answer>$"
# 模式2: <think> 和 <answer> 之间有两个换行
pattern2 = r"^<think>\n.*?\n</think>\n\n<answer>\n.*?\n</answer>$"
# 对每个响应进行正则匹配
matches_pattern = [re.match(pattern, response, re.S) for response in responses]
matches_pattern2 = [re.match(pattern2, response, re.S) for response in responses]
# 计算格式奖励: 匹配任一模式得 0.5 分, 否则 0 分
format_rewards = []
for match_pattern, match_pattern2 in zip(matches_pattern, matches_pattern2):
if match_pattern:
@ -61,8 +96,9 @@ def calculate_rewards(prompts, responses, reward_model, reward_tokenizer):
format_rewards.append(0.0)
rewards += torch.tensor(format_rewards, device=args.device)
# 2. 标记奖励(防止严格奖励稀疏,仅针对训练推理模型时使用)
# 计算标记奖励, 防止严格格式奖励过于稀疏
def mark_num(text):
"""统计文本中关键标记的出现次数, 每类标记正确出现一次得 0.25 分"""
reward = 0
if text.count("<think>") == 1:
reward += 0.25
@ -78,38 +114,46 @@ def calculate_rewards(prompts, responses, reward_model, reward_tokenizer):
rewards += torch.tensor(mark_rewards, device=args.device)
return rewards
# 初始化奖励张量为零, shape [B]
rewards = torch.zeros(len(responses), device=args.device)
# 格式奖励
# 如果是推理模型, 添加格式和标记奖励
if args.reasoning == 1:
rewards = reasoning_model_reward(rewards)
# 使用reward model计算整个response的奖励
# 使用奖励模型计算 response 的整体质量分数
with torch.no_grad():
reward_model_scores = []
for prompt, response in zip(prompts, responses):
# 解析 prompt 中的对话历史
pattern = r"<\|im_start\|>(system|user|assistant)\s+(.*?)<\|im_end\|>"
matches = re.findall(pattern, prompt, re.DOTALL)
messages = [{"role": role, "content": content.strip()} for role, content in matches]
# 构建完整的对话, 包含当前 response
tmp_chat = messages + [{"role": "assistant", "content": response}]
# 调用奖励模型获取评分
score = reward_model.get_score(reward_tokenizer, tmp_chat)
# 将分数裁剪到 [-scale, scale] 范围内, 避免极端值
scale = 3.0
score = max(min(score, scale), -scale)
# 当args.reasoning=1时额外计算<answer>内容的奖励
# 如果是推理模型, 额外计算 <answer> 内容部分的奖励
if args.reasoning == 1:
answer_match = re.search(r'<answer>(.*?)</answer>', response, re.DOTALL)
if answer_match:
# 提取 answer 标签内的内容
answer_content = answer_match.group(1).strip()
# 对answer内容单独计算reward
# 对 answer 内容单独计算奖励
tmp_chat = messages + [{"role": "assistant", "content": answer_content}]
answer_score = reward_model.get_score(reward_tokenizer, tmp_chat)
answer_score = max(min(answer_score, scale), -scale)
# 综合评分: 整体回答 40% + answer 内容 60%
score = score * 0.4 + answer_score * 0.6
reward_model_scores.append(score)
# 将奖励模型分数转换为张量并累加到总奖励
reward_model_scores = torch.tensor(reward_model_scores, device=args.device)
rewards += reward_model_scores
@ -117,81 +161,143 @@ def calculate_rewards(prompts, responses, reward_model, reward_tokenizer):
def ppo_train_epoch(epoch, loader, iters, old_actor_model, ref_model, actor_scheduler, critic_scheduler, reward_model, reward_tokenizer, start_step=0, wandb=None):
"""
PPO epoch 训练函数
Args:
- epoch: 当前训练轮次
- loader: 数据加载器
- iters: 总迭代次数
- old_actor_model: 旧策略模型, 用于计算重要性采样比率
- ref_model: 参考模型, 用于计算 KL 散度惩罚
- actor_scheduler: Actor 学习率调度器
- critic_scheduler: Critic 学习率调度器
- reward_model: 奖励模型
- reward_tokenizer: 奖励模型分词器
- start_step: 起始步数 (用于断点续训)
- wandb: 日志记录工具
Note:
使用全局变量: actor_model, critic_model, tokenizer, lm_config, args, actor_optimizer, critic_optimizer, autocast_ctx
"""
# 设置模型为训练模式
actor_model.train()
critic_model.train()
for step, batch in enumerate(loader, start=start_step + 1):
# 获取批次中的 prompt 列表
prompts = batch["prompt"] # list[str], length B
enc = tokenizer(prompts, return_tensors="pt", padding=True, truncation=True,
# 对 prompt 进行编码, 左填充以支持生成
enc = tokenizer(prompts, return_tensors="pt", padding=True, truncation=True,
max_length=args.max_seq_len, padding_side="left").to(args.device) # input_ids: [B, P], attention_mask: [B, P]
prompt_length = enc.input_ids.shape[1]
# 使用当前策略生成响应 (不计算梯度)
with torch.no_grad():
# DDP 模型需要使用 .module 访问 generate 方法
# DDP 包装模型需要通过 .module 访问 generate 方法
model_for_gen = actor_model.module if isinstance(actor_model, DistributedDataParallel) else actor_model
gen_out = model_for_gen.generate(
input_ids=enc.input_ids, attention_mask=enc.attention_mask,
max_new_tokens=args.max_gen_len, do_sample=True, temperature=0.8,
pad_token_id=tokenizer.pad_token_id, eos_token_id=tokenizer.eos_token_id) # [B, P+R]
# 解码生成的响应文本 (去除 prompt 部分)
responses_text = [tokenizer.decode(gen_out[i, prompt_length:], skip_special_tokens=True) for i in range(len(prompts))]
# 计算每个样本的奖励值
rewards = calculate_rewards(prompts, responses_text, reward_model, reward_tokenizer) # [B]
# 创建完整序列的注意力掩码 (非填充位置为 1)
full_mask = (gen_out != tokenizer.pad_token_id).long() # [B, P+R]
# 使用 Critic 模型评估生成序列的价值
values_seq = critic_model(input_ids=gen_out, attention_mask=full_mask) # [B, P+R]
# 找到每个序列最后一个非填充位置的索引
last_indices = (full_mask * torch.arange(full_mask.size(1), device=gen_out.device)).argmax(dim=1)
# 提取每个序列最终位置的价值估计
values = values_seq[torch.arange(values_seq.size(0), device=values_seq.device), last_indices] # [B]
# 计算优势函数: 奖励 - 价值 (使用 detach 阻止梯度流向 Critic)
advantages = rewards - values.detach() # [B]
# 使用混合精度计算 Actor 模型输出
with autocast_ctx:
res = actor_model(input_ids=gen_out, attention_mask=full_mask)
logits = res.logits # [B, P+R, V]
# 如果是 MoE 模型, 获取辅助损失; 否则为 0
aux_loss = res.aux_loss if lm_config.use_moe else torch.tensor(0.0, device=args.device)
# 准备标签 (向右移动一位, 用于计算每个位置的 log prob)
labels = gen_out[:, 1:].clone() # [B, P+R-1]
# 计算每个 token 的对数概率
logp_tokens = F.log_softmax(logits[:, :-1], dim=-1).gather(2, labels.unsqueeze(-1)).squeeze(-1) # [B, P+R-1]
seq_len = gen_out.size(1) - 1
# 创建响应部分的掩码 (只计算生成部分, 不计算 prompt)
resp_mask = torch.arange(seq_len, device=gen_out.device).unsqueeze(0) >= prompt_length - 1
# 最终掩码: 响应部分且非填充位置
final_mask = resp_mask & (~labels.eq(tokenizer.pad_token_id)) # [B, P+R-1]
# 计算 Actor 策略下整个响应的总对数概率
actor_logp = (logp_tokens * final_mask).sum(dim=1) # [B]
# 使用旧策略和参考模型计算对数概率 (不计算梯度)
with torch.no_grad():
# 旧策略模型的输出
old_logits = old_actor_model(input_ids=gen_out, attention_mask=full_mask).logits # [B, P+R, V]
old_logp_tokens = F.log_softmax(old_logits[:, :-1], dim=-1).gather(2, labels.unsqueeze(-1)).squeeze(-1) # [B, P+R-1]
old_logp = (old_logp_tokens * final_mask).sum(dim=1) # [B]
# 参考模型的输出 (用于 KL 散度计算)
ref_logits = ref_model(input_ids=gen_out, attention_mask=full_mask).logits # [B, P+R, V]
ref_logp_tokens = F.log_softmax(ref_logits[:, :-1], dim=-1).gather(2, labels.unsqueeze(-1)).squeeze(-1) # [B, P+R-1]
ref_logp = (ref_logp_tokens * final_mask).sum(dim=1) # [B]
# 计算与旧策略的 KL 散度
kl = (actor_logp - old_logp).mean() # scalar
# 计算与参考模型的 KL 散度 (作为惩罚项)
kl_ref = (actor_logp - ref_logp).mean() # scalar
# 计算重要性采样比率
ratio = torch.exp(actor_logp - old_logp) # [B]
# PPO 裁剪目标的第一项
surr1 = ratio * advantages # [B]
# PPO 裁剪目标的第二项 (裁剪比率防止过大更新)
surr2 = torch.clamp(ratio, 1.0 - args.clip_epsilon, 1.0 + args.clip_epsilon) * advantages # [B]
# 策略损失: 取裁剪后的最小值, 取负号后求平均 (因为我们要最大化目标)
policy_loss = -torch.min(surr1, surr2).mean() # scalar
# 价值损失: Critic 预测值与奖励的均方误差
value_loss = F.mse_loss(values, rewards) # scalar
# 总损失: 策略损失 + 价值系数 * 价值损失 + KL 系数 * KL 惩罚 + 辅助损失
loss = (policy_loss + args.vf_coef * value_loss + args.kl_coef * kl_ref + aux_loss) / args.accumulation_steps # scalar
# 反向传播计算梯度
loss.backward()
# 梯度累积: 达到指定步数后执行参数更新
if (step + 1) % args.accumulation_steps == 0:
# 对 Actor 和 Critic 的梯度进行裁剪, 防止梯度爆炸
clip_grad_norm_(actor_model.parameters(), args.grad_clip)
clip_grad_norm_(critic_model.parameters(), args.grad_clip)
# 执行优化器参数更新
actor_optimizer.step()
critic_optimizer.step()
# 更新学习率
actor_scheduler.step()
critic_scheduler.step()
# 清空梯度
actor_optimizer.zero_grad()
critic_optimizer.zero_grad()
# 只在主进程打印日志和记录指标
if is_main_process():
# 提取响应部分的 token IDs
response_ids = gen_out[:, enc.input_ids.shape[1]:]
# 检测每个响应是否包含 EOS 标记
is_eos = (response_ids == tokenizer.eos_token_id)
# 找到第一个 EOS 的位置
eos_indices = torch.argmax(is_eos.int(), dim=1)
# 检查每个响应是否实际包含 EOS
has_eos = is_eos.any(dim=1)
# 计算实际响应长度 (如果有 EOS 则为 EOS 位置+1, 否则为完整长度)
lengths = torch.where(has_eos, eos_indices + 1, torch.tensor(response_ids.shape[1], device=is_eos.device))
# 计算平均响应长度
avg_len = lengths.float().mean()
# 提取各项指标数值
actor_loss_val = policy_loss.item()
critic_loss_val = value_loss.item()
current_aux_loss = aux_loss.item()
@ -202,6 +308,7 @@ def ppo_train_epoch(epoch, loader, iters, old_actor_model, ref_model, actor_sche
actor_lr = actor_optimizer.param_groups[0]['lr']
critic_lr = critic_optimizer.param_groups[0]['lr']
# 记录到 wandb
if wandb is not None:
wandb.log({
"actor_loss": actor_loss_val,
@ -214,35 +321,45 @@ def ppo_train_epoch(epoch, loader, iters, old_actor_model, ref_model, actor_sche
"actor_lr": actor_lr,
})
# 打印训练日志
Logger(f"Epoch:[{epoch + 1}/{args.epochs}]({step}/{iters}), "
f"Actor Loss: {actor_loss_val:.4f}, Critic Loss: {critic_loss_val:.4f}, Aux Loss: {current_aux_loss:.4f}, "
f"Reward: {reward_val:.4f}, KL: {kl_val:.4f}, KL_ref: {kl_ref_val:.4f}, "
f"Avg Response Len: {avg_len_val:.2f}, Actor LR: {actor_lr:.8f}, Critic LR: {critic_lr:.8f}")
# 按指定频率更新旧策略模型
if (step + 1) % args.update_old_actor_freq == 0:
# 解包 DDP 模型获取原始模型
raw_actor = actor_model.module if isinstance(actor_model, DistributedDataParallel) else actor_model
raw_actor = getattr(raw_actor, '_orig_mod', raw_actor)
# 获取当前策略状态字典
state_dict = raw_actor.state_dict()
# 将参数 detached 后加载到旧策略模型
old_actor_model.load_state_dict({k: v.detach().cpu() for k, v in state_dict.items()})
old_actor_model.to(args.device)
# 按指定间隔保存检查点 (只在主进程执行)
if (step % args.save_interval == 0 or step == iters - 1) and is_main_process():
actor_model.eval()
# 根据是否使用 MoE 构建文件名后缀
moe_suffix = '_moe' if lm_config.use_moe else ''
ckp = f'{args.save_dir}/{args.save_weight}_{lm_config.hidden_size}{moe_suffix}.pth'
# 解包 DDP 模型获取原始 Actor 模型
raw_actor = actor_model.module if isinstance(actor_model, DistributedDataParallel) else actor_model
raw_actor = getattr(raw_actor, '_orig_mod', raw_actor)
# 获取状态字典并转换为半精度后保存
actor_state = raw_actor.state_dict()
torch.save({k: v.half().cpu() for k, v in actor_state.items()}, ckp)
# 使用 lm_checkpoint 保存完整状态(包括 critic
lm_checkpoint(lm_config, weight=args.save_weight, model=actor_model, optimizer=actor_optimizer,
# 使用 lm_checkpoint 保存完整训练状态 (包括 Critic)
lm_checkpoint(lm_config, weight=args.save_weight, model=actor_model, optimizer=actor_optimizer,
epoch=epoch, step=step, wandb=wandb, save_dir='../checkpoints',
scheduler=actor_scheduler, critic_model=critic_model,
scheduler=actor_scheduler, critic_model=critic_model,
critic_optimizer=critic_optimizer, critic_scheduler=critic_scheduler)
actor_model.train()
del actor_state
# 显式删除中间变量, 释放显存
del enc, gen_out, responses_text, rewards, full_mask, values_seq, values, advantages
del logits, labels, logp_tokens, final_mask, actor_logp, old_logits, old_logp, ref_logits, ref_logp
del kl, kl_ref, ratio, surr1, surr2, policy_loss, value_loss, loss
@ -265,117 +382,174 @@ if __name__ == "__main__":
parser.add_argument("--save_interval", type=int, default=10, help="模型保存间隔")
parser.add_argument('--hidden_size', default=512, type=int, help="隐藏层维度")
parser.add_argument('--num_hidden_layers', default=8, type=int, help="隐藏层数量")
parser.add_argument('--use_moe', default=0, type=int, choices=[0, 1], help="是否使用MoE架构0=否1=是)")
parser.add_argument('--use_moe', default=0, type=int, choices=[0, 1], help="是否使用MoE架构 (0=否, 1=是)")
parser.add_argument('--max_seq_len', default=66, type=int, help="Prompt最大长度")
parser.add_argument("--max_gen_len", type=int, default=1536, help="生成的最大长度")
parser.add_argument("--data_path", type=str, default="../dataset/rlaif-mini.jsonl", help="RLAIF数据路径")
parser.add_argument("--clip_epsilon", type=float, default=0.1, help="PPO裁剪参数")
parser.add_argument("--vf_coef", type=float, default=0.5, help="Value function系数")
parser.add_argument("--kl_coef", type=float, default=0.02, help="KL散度惩罚系数")
parser.add_argument("--reasoning", type=int, default=1, choices=[0, 1], help='推理模型类型0=普通模型1=推理模型)')
parser.add_argument("--reasoning", type=int, default=1, choices=[0, 1], help='推理模型类型 (0=普通模型, 1=推理模型)')
parser.add_argument("--update_old_actor_freq", type=int, default=4, help="更新old_actor_model的频率")
parser.add_argument("--reward_model_path", type=str, default="../../internlm2-1_8b-reward", help="Reward模型路径")
parser.add_argument('--from_resume', default=0, type=int, choices=[0, 1], help="是否自动检测&续训0=否1=是)")
parser.add_argument('--from_resume', default=0, type=int, choices=[0, 1], help="是否自动检测&续训 (0=否, 1=是)")
parser.add_argument("--use_wandb", action="store_true", help="是否使用wandb")
parser.add_argument("--wandb_project", type=str, default="MiniMind-PPO", help="wandb项目名")
parser.add_argument("--use_compile", default=0, type=int, choices=[0, 1], help="是否使用torch.compile加速0=否1=是)")
parser.add_argument("--use_compile", default=0, type=int, choices=[0, 1], help="是否使用torch.compile加速 (0=否, 1=是)")
args = parser.parse_args()
# ========== 1. 初始化环境和随机种子 ==========
# ========== 1.初始化环境和随机种子 ==========
# 初始化分布式训练环境, 获取本地 rank
local_rank = init_distributed_mode()
# 如果处于分布式训练环境, 根据 local_rank 设置设备
if dist.is_initialized(): args.device = f"cuda:{local_rank}"
# 设置随机种子, 分布式环境下根据 rank 添加偏移以确保不同进程种子不同
setup_seed(42 + (dist.get_rank() if dist.is_initialized() else 0))
# ========== 2. 配置目录、模型参数、检查ckp ==========
# ========== 2.配置目录、模型参数、检查 ckp ==========
# 创建模型保存目录
os.makedirs(args.save_dir, exist_ok=True)
# 创建模型配置对象
lm_config = MiniMindConfig(hidden_size=args.hidden_size, num_hidden_layers=args.num_hidden_layers, use_moe=bool(args.use_moe))
# 如果启用断点续训, 尝试加载检查点数据
ckp_data = lm_checkpoint(lm_config, weight=args.save_weight, save_dir='../checkpoints') if args.from_resume==1 else None
# ========== 3. 设置混合精度 ==========
# ========== 3.设置混合精度 ==========
# 判断设备类型
device_type = "cuda" if "cuda" in args.device else "cpu"
# 设置数据类型
dtype = torch.bfloat16 if args.dtype == "bfloat16" else torch.float16
autocast_ctx = nullcontext() if device_type == "cpu" else torch.cuda.amp.autocast(dtype=dtype)
# ========== 4. 配wandb ==========
# https://docs.pytorch.ac.cn/docs/stable/amp
# 新版废弃: autocast_ctx = nullcontext() if device_type == "cpu" else torch.cuda.amp.autocast(dtype=dtype)
autocast_ctx = nullcontext() if device_type == "cpu" else torch.amp.autocast(device_type=device_type, dtype=dtype)
# ========== 4. 配置 wandb ==========
wandb = None
if args.use_wandb and is_main_process():
# 使用 swanlab 作为 wandb 的替代
import swanlab as wandb
# 尝试从检查点恢复 wandb ID
wandb_id = ckp_data.get('wandb_id') if ckp_data else None
resume = 'must' if wandb_id else None
# 构建运行名称
wandb_run_name = f"MiniMind-PPO-Epoch-{args.epochs}-BS-{args.batch_size}-LR-{args.learning_rate}"
# 初始化 wandb
wandb.init(project=args.wandb_project, name=wandb_run_name, id=wandb_id, resume=resume)
# ========== 5. 初始化模型和数据 ==========
# 根据是否训练推理模型选择基础权重名称
base_weight = "reason" if args.reasoning == 1 else "full_sft"
# Actor模型
# Actor 模型 (策略网络, 需要训练)
actor_model, tokenizer = init_model(lm_config, base_weight, device=args.device)
if args.use_compile == 1:
# 使用 torch.compile 加速模型 (需要 PyTorch 2.0+)
actor_model = torch.compile(actor_model)
Logger('torch.compile enabled')
# Old Actor模型
# Old Actor 模型 (旧策略, 用于 PPO 的重要性采样)
old_actor_model, _ = init_model(lm_config, base_weight, device=args.device)
# 设置为评估模式, 不计算梯度
old_actor_model = old_actor_model.eval().requires_grad_(False)
# Reference模型
# Reference 模型 (参考模型, 用于 KL 散度惩罚)
ref_model, _ = init_model(lm_config, base_weight, device=args.device)
ref_model = ref_model.eval().requires_grad_(False)
# Critic模型
# Critic 模型 (价值网络, 估计状态价值)
moe_suffix = '_moe' if lm_config.use_moe else ''
ckp = f'{args.save_dir}/{base_weight}_{lm_config.hidden_size}{moe_suffix}.pth'
# 加载基础模型权重初始化 Critic
state_dict = torch.load(ckp, map_location=args.device)
critic_model = CriticModel(lm_config)
critic_model.load_state_dict(state_dict, strict=False)
critic_model = critic_model.to(args.device)
# Reward模型
reward_model = AutoModel.from_pretrained(
args.reward_model_path, torch_dtype=torch.float16, trust_remote_code=True
)
# Reward 模型 (奖励模型, 提供奖励信号)
reward_model = AutoModel.from_pretrained(args.reward_model_path, torch_dtype=torch.float16, trust_remote_code=True)
reward_model = reward_model.to(args.device).eval().requires_grad_(False)
reward_tokenizer = AutoTokenizer.from_pretrained(args.reward_model_path, trust_remote_code=True)
# 数据和优化器
# 数据集和优化器配置
train_ds = RLAIFDataset(args.data_path, tokenizer, max_length=(args.max_seq_len + args.max_gen_len))
train_sampler = DistributedSampler(train_ds) if dist.is_initialized() else None
actor_optimizer = optim.AdamW(actor_model.parameters(), lr=args.learning_rate)
critic_optimizer = optim.AdamW(critic_model.parameters(), lr=args.critic_learning_rate)
# 计算训练迭代次数
loader_for_count = DataLoader(train_ds, batch_size=args.batch_size, sampler=train_sampler)
iters = len(loader_for_count)
# 计算总优化步数 (考虑梯度累积)
total_optimizer_steps = (iters // args.accumulation_steps) * args.epochs
# 学习率调度器 (余弦退火)
actor_scheduler = CosineAnnealingLR(actor_optimizer, T_max=total_optimizer_steps, eta_min=args.learning_rate / 10)
critic_scheduler = CosineAnnealingLR(critic_optimizer, T_max=total_optimizer_steps, eta_min=args.critic_learning_rate / 10)
# ========== 6. 从ckp恢复状态 ==========
# ==========6. 从ckp 恢复状态 ==========
start_epoch, start_step = 0, 0
if ckp_data:
# 恢复模型权重
actor_model.load_state_dict(ckp_data['model'])
critic_model.load_state_dict(ckp_data['critic_model'])
# 恢复优化器状态
actor_optimizer.load_state_dict(ckp_data['optimizer'])
critic_optimizer.load_state_dict(ckp_data['critic_optimizer'])
# 恢复学习率调度器状态
actor_scheduler.load_state_dict(ckp_data['scheduler'])
critic_scheduler.load_state_dict(ckp_data['critic_scheduler'])
# 恢复训练进度
start_epoch = ckp_data['epoch']
start_step = ckp_data.get('step', 0)
# ========== 7. DDP包模型 ==========
# ========== 7.DDP 模型 ==========
if dist.is_initialized():
# 设置 DDP 忽略的参数 (位置编码相关, 不需要梯度同步)
actor_model._ddp_params_and_buffers_to_ignore = {"freqs_cos", "freqs_sin"}
critic_model._ddp_params_and_buffers_to_ignore = {"freqs_cos", "freqs_sin"}
# 包装为 DDP 模型
actor_model = DistributedDataParallel(actor_model, device_ids=[local_rank])
critic_model = DistributedDataParallel(critic_model, device_ids=[local_rank])
# 确保旧策略模型在正确设备上
old_actor_model.to(args.device)
# ========== 8. 开始训练 ==========
# ========== 8.开始训练 ==========
for epoch in range(start_epoch, args.epochs):
# 设置分布式采样器的 epoch (确保每个 epoch 数据打乱方式不同)
train_sampler and train_sampler.set_epoch(epoch)
# 每个 epoch 使用不同的随机种子
setup_seed(42 + epoch); indices = torch.randperm(len(train_ds)).tolist()
# 如果是续训的第一个 epoch, 需要跳过已经训练过的 step
skip = start_step if (epoch == start_epoch and start_step > 0) else 0
# 创建支持跳过指定批次数量的采样器
batch_sampler = SkipBatchSampler(train_sampler or indices, args.batch_size, skip)
# 创建数据加载器
loader = DataLoader(train_ds, batch_sampler=batch_sampler, num_workers=args.num_workers, pin_memory=True)
if skip > 0:
Logger(f'Epoch [{epoch + 1}/{args.epochs}]: 跳过前{start_step}个step从step {start_step + 1}开始')
ppo_train_epoch(epoch, loader, len(loader) + skip, old_actor_model, ref_model,
actor_scheduler, critic_scheduler, reward_model, reward_tokenizer, start_step, wandb)
if skip > 0:
Logger(f'Epoch [{epoch + 1}/{args.epochs}]: 跳过前{start_step}个step, 从step {start_step + 1}开始')
ppo_train_epoch(
epoch,
loader,
len(loader) + skip,
old_actor_model,
ref_model,
actor_scheduler,
critic_scheduler,
reward_model,
reward_tokenizer,
start_step,
wandb)
else:
ppo_train_epoch(epoch, loader, len(loader), old_actor_model, ref_model,
actor_scheduler, critic_scheduler, reward_model, reward_tokenizer, 0, wandb)
# ========== 9. 清理分布进程 ==========
ppo_train_epoch(
epoch,
loader,
len(loader),
old_actor_model,
ref_model,
actor_scheduler,
critic_scheduler,
reward_model,
reward_tokenizer,
0,
wandb)
# ========== 9.清理分布式进程 ==========
if dist.is_initialized(): dist.destroy_process_group()

View File

@ -1,72 +1,131 @@
import os
import sys
# 设置包名
__package__ = "trainer"
# 将项目根目录添加到系统路径, 确保能够导入同级目录的模块
sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), '..')))
import argparse
import time
import warnings
import torch
# 分布式训练支持
import torch.distributed as dist
# 上下文管理器
# https://docs.python.org/zh-cn/3.14/library/contextlib.html#contextlib.nullcontext
from contextlib import nullcontext
from torch import optim, nn
# 分布式数据并行 DDP
from torch.nn.parallel import DistributedDataParallel
# 数据加载器和分布式采样器
from torch.utils.data import DataLoader, DistributedSampler
from model.model_minimind import MiniMindConfig
from dataset.lm_dataset import PretrainDataset
# 预训练数据集
from dataset.lm_dataset import PretrainDataset
# 导入训练工具函数:学习率计算、日志记录、检查点保存、分布式初始化等
from trainer.trainer_utils import get_lr, Logger, is_main_process, lm_checkpoint, init_distributed_mode, setup_seed, init_model, SkipBatchSampler
# 忽略所有警告信息,减少输出干扰
import warnings
warnings.filterwarnings('ignore')
def train_epoch(epoch, loader, iters, start_step=0, wandb=None):
"""
执行单个训练 epoch
Args:
- epoch: 当前 epoch 编号
- loader: 数据加载器
- iters: 当前 epoch 的总迭代次数
- start_step: 起始 step 编号 (用于断点续训)
- wandb: 实验日志工具实例
"""
# 记录当前 epoch 开始时间
start_time = time.time()
# 遍历数据加载器, 获取批次数据
for step, (input_ids, labels) in enumerate(loader, start=start_step + 1):
# 将 数据 & 标签 移动到指定设备
input_ids = input_ids.to(args.device)
labels = labels.to(args.device)
# 根据当前 step 计算当前学习率
lr = get_lr(epoch * iters + step, args.epochs * iters, args.learning_rate)
# 更新优化器中所有参数组的学习率
for param_group in optimizer.param_groups:
param_group['lr'] = lr
# 使用混合精度上下文进行前向传播
# - 在 __name__ == "__main__" 中定义
# - autocast_ctx = nullcontext() if device_type == "cpu" else torch.amp.autocast(dtype=dtype)
with autocast_ctx:
# 模型前向传播, 其中包含损失
res = model(input_ids, labels=labels)
# 合并主损失和辅助损失 (如 MoE 的负载均衡损失)
loss = res.loss + res.aux_loss
# 根据梯度累积步数进行缩放
loss = loss / args.accumulation_steps
# 反向传播计算梯度
# - 在 __name__ == "__main__" 中定义
# - scaler = torch.amp.GradScaler("cuda", enabled=(args.dtype == 'float16'))
scaler.scale(loss).backward()
# 当达到梯度累积步数时,执行优化器更新
if (step + 1) % args.accumulation_steps == 0:
# 取消梯度缩放, 为优化器步骤做准备
scaler.unscale_(optimizer)
# 对梯度进行裁剪, 防止梯度爆炸
torch.nn.utils.clip_grad_norm_(model.parameters(), args.grad_clip)
# 执行优化器步骤更新参数
scaler.step(optimizer)
# 更新缩放器状态
scaler.update()
# 清空梯度, 设置为 None 以节省内存
optimizer.zero_grad(set_to_none=True)
# 达到日志打印间隔时, 记录训练状态
if step % args.log_interval == 0 or step == iters - 1:
# 计算已消耗时间
spend_time = time.time() - start_time
# 获取当前损失值
current_loss = loss.item() * args.accumulation_steps
# 获取辅助损失值
current_aux_loss = res.aux_loss.item() if res.aux_loss is not None else 0.0
# 计算 logits 损失 (主损失)
current_logits_loss = current_loss - current_aux_loss
# 获取当前学习率
current_lr = optimizer.param_groups[-1]['lr']
# 估算剩余时间 (分钟)
eta_min = spend_time / (step + 1) * iters // 60 - spend_time // 60
# 打印训练日志
Logger(f'Epoch:[{epoch + 1}/{args.epochs}]({step}/{iters}), loss: {current_loss:.4f}, logits_loss: {current_logits_loss:.4f}, aux_loss: {current_aux_loss:.4f}, lr: {current_lr:.8f}, epoch_time: {eta_min:.1f}min')
# 记录到 wandb (如果启用)
if wandb: wandb.log({"loss": current_loss, "logits_loss": current_logits_loss, "aux_loss": current_aux_loss, "learning_rate": current_lr, "epoch_time": eta_min})
# 达到保存间隔且为主进程时, 保存模型检查点
if (step % args.save_interval == 0 or step == iters - 1) and is_main_process():
# 设置为评估模式
model.eval()
# 生成模型文件名后缀, 区分 MoE 和普通模型
moe_suffix = '_moe' if lm_config.use_moe else ''
ckp = f'{args.save_dir}/{args.save_weight}_{lm_config.hidden_size}{moe_suffix}.pth'
# 获取原始模型 (移除 DDP 包装)
raw_model = model.module if isinstance(model, DistributedDataParallel) else model
raw_model = getattr(raw_model, '_orig_mod', raw_model)
# 获取模型状态字典
state_dict = raw_model.state_dict()
# 保存模型权重 (转换为半精度以节省空间)
torch.save({k: v.half().cpu() for k, v in state_dict.items()}, ckp)
# 保存完整训练检查点 (包含配置、优化器状态等)
lm_checkpoint(lm_config, weight=args.save_weight, model=model, optimizer=optimizer, scaler=scaler, epoch=epoch, step=step, wandb=wandb, save_dir='../checkpoints')
# 恢复训练模式
model.train()
# 释放显存
del state_dict
# 清理当前批次的数据, 释放显存
del input_ids, labels, res, loss
@ -74,7 +133,7 @@ if __name__ == "__main__":
parser = argparse.ArgumentParser(description="MiniMind Pretraining")
parser.add_argument("--save_dir", type=str, default="../out", help="模型保存目录")
parser.add_argument('--save_weight', default='pretrain', type=str, help="保存权重的前缀名")
parser.add_argument("--epochs", type=int, default=1, help="训练轮数建议1轮zero或2-6轮充分训练")
parser.add_argument("--epochs", type=int, default=1, help="训练轮数 (建议 1 轮 zero 或 2-6 轮充分训练)")
parser.add_argument("--batch_size", type=int, default=32, help="batch size")
parser.add_argument("--learning_rate", type=float, default=5e-4, help="初始学习率")
parser.add_argument("--device", type=str, default="cuda:0" if torch.cuda.is_available() else "cpu", help="训练设备")
@ -86,52 +145,76 @@ if __name__ == "__main__":
parser.add_argument("--save_interval", type=int, default=1000, help="模型保存间隔")
parser.add_argument('--hidden_size', default=512, type=int, help="隐藏层维度")
parser.add_argument('--num_hidden_layers', default=8, type=int, help="隐藏层数量")
parser.add_argument('--max_seq_len', default=340, type=int, help="训练的最大截断长度中文1token≈1.5~1.7字符)")
parser.add_argument('--use_moe', default=0, type=int, choices=[0, 1], help="是否使用MoE架构0=否1=是)")
parser.add_argument('--max_seq_len', default=340, type=int, help="训练的最大截断长度 (中文 1token≈1.5~1.7 字符)")
parser.add_argument('--use_moe', default=0, type=int, choices=[0, 1], help="是否使用 MoE 架构 (0=否, 1=是)")
parser.add_argument("--data_path", type=str, default="../dataset/pretrain_hq.jsonl", help="预训练数据路径")
parser.add_argument('--from_weight', default='none', type=str, help="基于哪个权重训练为none则从头开始")
parser.add_argument('--from_resume', default=0, type=int, choices=[0, 1], help="是否自动检测&续训0=否1=是)")
parser.add_argument("--use_wandb", action="store_true", help="是否使用wandb")
parser.add_argument("--wandb_project", type=str, default="MiniMind-Pretrain", help="wandb项目名")
parser.add_argument("--use_compile", default=0, type=int, choices=[0, 1], help="是否使用torch.compile加速0=否1=是)")
parser.add_argument('--from_weight', default='none', type=str, help="基于哪个权重训练, 为 none 则从头开始")
parser.add_argument('--from_resume', default=0, type=int, choices=[0, 1], help="是否自动检测 & 续训 (0=否, 1=是)")
parser.add_argument("--use_wandb", action="store_true", help="是否使用 wandb")
parser.add_argument("--wandb_project", type=str, default="MiniMind-Pretrain", help="wandb 项目名")
parser.add_argument("--use_compile", default=0, type=int, choices=[0, 1], help="是否使用 torch.compile 加速 (0=否, 1=是)")
args = parser.parse_args()
# ========== 1. 初始化环境和随机种子 ==========
# ========== 1.初始化环境和随机种子 ===========
# 初始化分布式训练模式, 返回本地 GPU 编号
local_rank = init_distributed_mode()
# 如果已初始化分布式环境, 更新设备为对应 GPU
if dist.is_initialized(): args.device = f"cuda:{local_rank}"
# 设置随机种子, 确保实验可复现 (不同进程使用不同种子避免同步)
setup_seed(42 + (dist.get_rank() if dist.is_initialized() else 0))
# ========== 2. 配置目录、模型参数、检查ckp ==========
# ========== 2.配置目录、模型参数、检查 ckp ===========
# 创建保存目录 (如果不存在)
os.makedirs(args.save_dir, exist_ok=True)
# 根据命令行参数创建 MiniMind 模型配置
lm_config = MiniMindConfig(hidden_size=args.hidden_size, num_hidden_layers=args.num_hidden_layers, use_moe=bool(args.use_moe))
# 如果启用断点续训, 尝试加载已有检查点信息
ckp_data = lm_checkpoint(lm_config, weight=args.save_weight, save_dir='../checkpoints') if args.from_resume==1 else None
# ========== 3. 设置混合精度 ==========
# ========== 3.设置混合精度 ==========
device_type = "cuda" if "cuda" in args.device else "cpu"
# 根据参数选择精度类型 (bfloat16 或 float16)
dtype = torch.bfloat16 if args.dtype == "bfloat16" else torch.float16
autocast_ctx = nullcontext() if device_type == "cpu" else torch.cuda.amp.autocast(dtype=dtype)
# https://docs.pytorch.ac.cn/docs/stable/amp
# 新版废弃: autocast_ctx = nullcontext() if device_type == "cpu" else torch.cuda.amp.autocast(dtype=dtype)
autocast_ctx = nullcontext() if device_type == "cpu" else torch.amp.autocast(device_type=device_type, dtype=dtype)
# ========== 4. 配wandb ==========
# ========== 4.配置 wandb ==========
wandb = None
# 仅主进程启用 wandb, 避免日志重复
if args.use_wandb and is_main_process():
# 导入日志工具
import swanlab as wandb
# 从检查点获取 wandb run ID (用于恢复实验)
wandb_id = ckp_data.get('wandb_id') if ckp_data else None
# 设置恢复模式:如果有 ID 则恢复之前实验, 否则创建新实验
resume = 'must' if wandb_id else None
# 生成实验名称, 包含关键超参数
wandb_run_name = f"MiniMind-Pretrain-Epoch-{args.epochs}-BatchSize-{args.batch_size}-LearningRate-{args.learning_rate}"
# 初始化 wandb 实验
wandb.init(project=args.wandb_project, name=wandb_run_name, id=wandb_id, resume=resume)
# ========== 5. 定义模型、数据、优化器 ==========
# ========== 5.定义模型、数据、优化器 ==========
# 初始化模型和分词器
model, tokenizer = init_model(lm_config, args.from_weight, device=args.device)
# 如果启用 torch.compile, 对模型进行编译优化 (需要 PyTorch 2.0+ & Linux)
if args.use_compile == 1:
model = torch.compile(model)
Logger('torch.compile enabled')
# 创建预训练数据集
train_ds = PretrainDataset(args.data_path, tokenizer, max_length=args.max_seq_len)
# 如果使用分布式训练, 创建分布式采样器
train_sampler = DistributedSampler(train_ds) if dist.is_initialized() else None
scaler = torch.cuda.amp.GradScaler(enabled=(args.dtype == 'float16'))
# 创建梯度缩放器 (用于混合精度训练)
# https://docs.pytorch.ac.cn/docs/stable/amp.html#gradient-scaling
# 新版废弃: scaler = torch.cuda.amp.GradScaler(enabled=(args.dtype == 'float16'))
scaler = torch.amp.GradScaler("cuda", enabled=(args.dtype == 'float16'))
# 创建 AdamW 优化器
optimizer = optim.AdamW(model.parameters(), lr=args.learning_rate)
# ========== 6. 从ckp恢复状态 ==========
# ========== 6. ckp恢复状态 ==========
start_epoch, start_step = 0, 0
# 如果有检查点数据, 恢复训练状态
if ckp_data:
model.load_state_dict(ckp_data['model'])
optimizer.load_state_dict(ckp_data['optimizer'])
@ -139,23 +222,33 @@ if __name__ == "__main__":
start_epoch = ckp_data['epoch']
start_step = ckp_data.get('step', 0)
# ========== 7. DDP包模型 ==========
# ========== 7.DDP 包装模型 ==========
# 如果使用分布式训练, 使用包装模型为 DDP 模式
if dist.is_initialized():
# 忽略频率向量 (避免 DDP 重复广播)
model._ddp_params_and_buffers_to_ignore = {"freqs_cos", "freqs_sin"}
model = DistributedDataParallel(model, device_ids=[local_rank])
# ========== 8. 开始训练 ==========
# ========== 8.开始训练 ==========
# 遍历所有 epoch
for epoch in range(start_epoch, args.epochs):
# 如果使用分布式采样器, 设置 epoch 以确保数据打乱
train_sampler and train_sampler.set_epoch(epoch)
# 设置随机种子并生成打乱的索引顺序
setup_seed(42 + epoch); indices = torch.randperm(len(train_ds)).tolist()
# 计算需要跳过的 step 数 (断点续训时)
skip = start_step if (epoch == start_epoch and start_step > 0) else 0
# 创建批次采样器, 支持跳过指定数量的批次
batch_sampler = SkipBatchSampler(train_sampler or indices, args.batch_size, skip)
# 创建数据加载器
loader = DataLoader(train_ds, batch_sampler=batch_sampler, num_workers=args.num_workers, pin_memory=True)
# 如果有跳过, 跳过前 start_step 个 step
if skip > 0:
Logger(f'Epoch [{epoch + 1}/{args.epochs}]: 跳过前{start_step}个stepstep {start_step + 1}开始')
Logger(f'Epoch [{epoch + 1}/{args.epochs}]: 跳过前 {start_step} 个 step, 从 step {start_step + 1} 开始')
train_epoch(epoch, loader, len(loader) + skip, start_step, wandb)
else:
train_epoch(epoch, loader, len(loader), 0, wandb)
# ========== 9. 清理分布进程 ==========
# ========== 9.清理分布进程 ==========
# 销毁分布式进程组, 释放资源
if dist.is_initialized(): dist.destroy_process_group()

View File

@ -1,12 +1,14 @@
# 导入标准库
import os
import sys
# 设置包名, 确保能够正确导入同级目录的模块
__package__ = "trainer"
# 将项目根目录添加到系统路径
sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), '..')))
import argparse
import time
import warnings
import torch
import torch.distributed as dist
from contextlib import nullcontext
@ -17,74 +19,138 @@ from model.model_minimind import MiniMindConfig
from dataset.lm_dataset import SFTDataset
from trainer.trainer_utils import get_lr, Logger, is_main_process, lm_checkpoint, init_distributed_mode, setup_seed, init_model, SkipBatchSampler
# 忽略警告信息, 减少控制台输出干扰
import warnings
warnings.filterwarnings('ignore')
def train_epoch(epoch, loader, iters, tokenizer, lm_config, start_step=0, wandb=None):
start_of_think_ids = tokenizer('<think>').input_ids
"""
单个 epoch 的训练函数
执行模型的推理蒸馏训练, 包括前向传播损失计算梯度累积和反向传播
使用特殊 token (<think>, </think>, <answer>, </answer>) 对思考过程和答案部分进行加权
Args:
- epoch: 当前训练轮次
- loader: 数据加载器, 提供 (input_ids, labels) 批次
- iters: 当前 epoch 的总迭代步数
- tokenizer: 分词器对象, 用于获取特殊 token ID
- lm_config: 模型配置对象, 包含 use_moe 等属性
- start_step: 起始步数, 用于断点续训时跳过已训练的步数
- wandb: Weights & Biases 日志对象, 可选
"""
# 获取特殊 token 的 ID, 用于在损失计算中对思考过程加权
# <think> 和 </think> 标记思考过程的开始和结束
start_of_think_ids = tokenizer(' <think>').input_ids
end_of_think_ids = tokenizer('</think>').input_ids
# <answer> 和 </answer> 标记答案的开始和结束
start_of_answer_ids = tokenizer('<answer>').input_ids
end_of_answer_ids = tokenizer('</answer>').input_ids
# 使用不 reduction 的交叉熵损失, 以便后续应用自定义 loss mask
loss_fct = nn.CrossEntropyLoss(reduction='none')
# 记录 epoch 开始时间, 用于计算训练速度和 ETA
start_time = time.time()
# 遍历数据加载器, 从 start_step + 1 开始计数
for step, (input_ids, labels) in enumerate(loader, start=start_step + 1):
# 将输入数据移动到训练设备
input_ids = input_ids.to(args.device)
labels = labels.to(args.device)
# 计算当前步的学习率, 使用余弦退火策略
lr = get_lr(epoch * iters + step, args.epochs * iters, args.learning_rate)
# 更新优化器的学习率
for param_group in optimizer.param_groups:
param_group['lr'] = lr
# 使用自动混合精度上下文进行前向传播
with autocast_ctx:
# 模型前向传播, 获取输出结果
res = model(input_ids)
# 对 logits 和 labels 进行移位, 以进行因果语言建模
# shift_logits: 预测下一个 token 的 logits
# shift_labels: 实际的下一个 token 的标签
shift_logits = res.logits[..., :-1, :].contiguous()
shift_labels = labels[..., 1:].contiguous()
# 计算每个位置的交叉熵损失
loss = loss_fct(shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1)).view(shift_labels.size())
# 构建损失掩码, 标记有效标签位置 (非 -100 的 padding)
loss_mask = (shift_labels != -100).float()
sp_ids = torch.isin(shift_labels.view(-1),
torch.tensor(start_of_think_ids + end_of_think_ids
+ start_of_answer_ids + end_of_answer_ids
).to(args.device))
# 找出特殊 token 的位置 (思考过程和答案的标记)
sp_ids = torch.isin(
shift_labels.view(-1),
torch.tensor(start_of_think_ids + end_of_think_ids + start_of_answer_ids + end_of_answer_ids).to(args.device)
)
# 对 loss mask 进行加权: 特殊 token 位置权重设为 10, 增强对推理过程的学习
loss_mask_flat = loss_mask.view(-1)
loss_mask_sum = loss_mask_flat.sum()
loss_mask_flat[sp_ids] = 10
loss_mask = loss_mask_flat.view(shift_labels.size())
# 计算加权后的 logits 损失
logits_loss = (loss * loss_mask).sum() / loss_mask_sum
# 总损失 = 加权 logits 损失 + MoE 辅助损失 (如使用 MoE)
loss = logits_loss + res.aux_loss
# 梯度累积: 将损失除以累积步数, 模拟大批量训练
loss = loss / args.accumulation_steps
# 使用梯度缩放器进行反向传播, 防止 FP16/BF16 梯度下溢
scaler.scale(loss).backward()
# 当达到梯度累积步数时, 执行优化器更新
if (step + 1) % args.accumulation_steps == 0:
# 取消梯度缩放, 以便进行梯度裁剪
scaler.unscale_(optimizer)
# 梯度裁剪, 防止梯度爆炸, 保持训练稳定
torch.nn.utils.clip_grad_norm_(model.parameters(), args.grad_clip)
# 执行优化器更新
scaler.step(optimizer)
# 更新梯度缩放器的缩放因子
scaler.update()
# 清空梯度, set_to_none=True 节省内存
optimizer.zero_grad(set_to_none=True)
# 按日志间隔打印训练状态, 或在最后一个 step 打印
if step % args.log_interval == 0 or step == iters - 1:
# 计算已花费的时间和各项损失值
spend_time = time.time() - start_time
# 恢复原始损失值 (之前除以了 accumulation_steps)
current_loss = loss.item() * args.accumulation_steps
# 获取 MoE 辅助损失 (如果存在)
current_aux_loss = res.aux_loss.item() if res.aux_loss is not None else 0.0
current_logits_loss = logits_loss.item()
current_lr = optimizer.param_groups[-1]['lr']
# 计算预计剩余时间 (分钟)
eta_min = spend_time / (step + 1) * iters // 60 - spend_time // 60
# 打印训练日志
Logger(f'Epoch:[{epoch + 1}/{args.epochs}]({step}/{iters}), loss: {current_loss:.4f}, logits_loss: {current_logits_loss:.4f}, aux_loss: {current_aux_loss:.4f}, lr: {current_lr:.8f}, epoch_time: {eta_min:.1f}min')
# 如果使用 wandb, 记录训练指标
if wandb: wandb.log({"loss": current_loss, "logits_loss": current_logits_loss, "aux_loss": current_aux_loss, "learning_rate": current_lr, "epoch_time": eta_min})
# 按保存间隔保存模型检查点, 仅在主进程执行
if (step % args.save_interval == 0 or step == iters - 1) and is_main_process():
# 切换到评估模式, 准备保存模型
model.eval()
# 根据是否使用 MoE 确定文件名后缀
moe_suffix = '_moe' if lm_config.use_moe else ''
# 构建模型保存路径
ckp = f'{args.save_dir}/{args.save_weight}_{lm_config.hidden_size}{moe_suffix}.pth'
# 如果是 DDP 模型, 解包获取原始模型
raw_model = model.module if isinstance(model, DistributedDataParallel) else model
# 处理 torch.compile 包装的情况
raw_model = getattr(raw_model, '_orig_mod', raw_model)
# 获取模型状态字典
state_dict = raw_model.state_dict()
# 将权重转换为半精度 (FP16) 并保存到 CPU, 减小文件大小
torch.save({k: v.half().cpu() for k, v in state_dict.items()}, ckp)
# 保存完整检查点 (包含优化器状态等, 用于断点续训)
lm_checkpoint(lm_config, weight=args.save_weight, model=model, optimizer=optimizer, scaler=scaler, epoch=epoch, step=step, wandb=wandb, save_dir='../checkpoints')
# 切换回训练模式
model.train()
# 释放状态字典内存
del state_dict
# 删除批次数据, 释放 GPU 内存
del input_ids, labels, res, loss
@ -104,76 +170,110 @@ if __name__ == "__main__":
parser.add_argument("--save_interval", type=int, default=100, help="模型保存间隔")
parser.add_argument('--hidden_size', default=512, type=int, help="隐藏层维度")
parser.add_argument('--num_hidden_layers', default=8, type=int, help="隐藏层数量")
parser.add_argument('--max_seq_len', default=720, type=int, help="训练的最大截断长度中文1token≈1.5~1.7字符)")
parser.add_argument('--use_moe', default=0, type=int, choices=[0, 1], help="是否使用MoE架构0=否1=是)")
parser.add_argument('--max_seq_len', default=720, type=int, help="训练的最大截断长度 (中文 1 token ≈ 1.5 ~ 1.7 字符)")
parser.add_argument('--use_moe', default=0, type=int, choices=[0, 1], help="是否使用MoE架构 (0=否, 1=是)")
parser.add_argument("--data_path", type=str, default="../dataset/r1_mix_1024.jsonl", help="推理蒸馏数据路径")
parser.add_argument('--from_weight', default='dpo', type=str, help="基于哪个权重训练默认dpo")
parser.add_argument('--from_resume', default=0, type=int, choices=[0, 1], help="是否自动检测&续训0=否1=是)")
parser.add_argument('--from_weight', default='dpo', type=str, help="基于哪个权重训练, 默认dpo")
parser.add_argument('--from_resume', default=0, type=int, choices=[0, 1], help="是否自动检测&续训 (0=否, 1=是)")
parser.add_argument("--use_wandb", action="store_true", help="是否使用wandb")
parser.add_argument("--wandb_project", type=str, default="MiniMind-Reasoning", help="wandb项目名")
parser.add_argument("--use_compile", default=0, type=int, choices=[0, 1], help="是否使用torch.compile加速0=否1=是)")
parser.add_argument("--use_compile", default=0, type=int, choices=[0, 1], help="是否使用torch.compile加速 (0=否, 1=是)")
args = parser.parse_args()
# ========== 1. 初始化环境和随机种子 ==========
# ========== 1.初始化分布式环境和随机种子 ==========
# 初始化分布式训练模式, 获取本地 rank
local_rank = init_distributed_mode()
# 如果分布式训练已初始化, 根据 local_rank 设置设备
if dist.is_initialized(): args.device = f"cuda:{local_rank}"
# 设置随机种子, 分布式环境下每个 rank 使用不同的种子
setup_seed(42 + (dist.get_rank() if dist.is_initialized() else 0))
# ========== 2. 配置目录、模型参数、检查ckp ==========
# ========== 2.配置目录、模型参数、检查检查点 ==========
# 创建保存目录 (如果不存在)
os.makedirs(args.save_dir, exist_ok=True)
# 初始化模型配置
lm_config = MiniMindConfig(hidden_size=args.hidden_size, num_hidden_layers=args.num_hidden_layers, use_moe=bool(args.use_moe))
# 如果启用断点续训, 尝试加载检查点
ckp_data = lm_checkpoint(lm_config, weight=args.save_weight, save_dir='../checkpoints') if args.from_resume==1 else None
# ========== 3. 设置混合精度 ==========
# ========== 3.设置混合精度训练 ==========
# 确定设备类型
device_type = "cuda" if "cuda" in args.device else "cpu"
# 根据参数设置数据类型 (BF16 或 FP16)
dtype = torch.bfloat16 if args.dtype == "bfloat16" else torch.float16
autocast_ctx = nullcontext() if device_type == "cpu" else torch.cuda.amp.autocast(dtype=dtype)
# ========== 4. 配wandb ==========
# https://docs.pytorch.ac.cn/docs/stable/amp
# 新版废弃: autocast_ctx = nullcontext() if device_type == "cpu" else torch.cuda.amp.autocast(dtype=dtype)
autocast_ctx = nullcontext() if device_type == "cpu" else torch.amp.autocast(device_type=device_type, dtype=dtype)
# ========== 4.配置 wandb 日志 ==========
wandb = None
# 仅在主进程且启用 wandb 时初始化
if args.use_wandb and is_main_process():
import swanlab as wandb
# 断点续训时恢复 wandb 运行 ID
wandb_id = ckp_data.get('wandb_id') if ckp_data else None
resume = 'must' if wandb_id else None
# 构建运行名称
wandb_run_name = f"MiniMind-Reasoning-Epoch-{args.epochs}-BatchSize-{args.batch_size}-LR-{args.learning_rate}"
wandb.init(project=args.wandb_project, name=wandb_run_name, id=wandb_id, resume=resume)
# ========== 5. 定义模型、数据、优化器 ==========
# ========== 5.定义模型、数据集、优化器 ==========
# 初始化模型和分词器
model, tokenizer = init_model(lm_config, args.from_weight, device=args.device)
# 如果使用 torch.compile 加速 (PyTorch 2.0+ 特性)
if args.use_compile == 1:
model = torch.compile(model)
Logger('torch.compile enabled')
# 创建训练数据集
train_ds = SFTDataset(args.data_path, tokenizer, max_length=args.max_seq_len)
# 如果是分布式训练, 创建分布式采样器
train_sampler = DistributedSampler(train_ds) if dist.is_initialized() else None
scaler = torch.cuda.amp.GradScaler(enabled=(args.dtype == 'float16'))
# 创建梯度缩放器, 仅在 float16 模式下启用 (用于混合精度训练, 防止梯度下溢)
# https://docs.pytorch.ac.cn/docs/stable/amp.html#gradient-scaling
# 新版废弃: scaler = torch.cuda.amp.GradScaler(enabled=(args.dtype == 'float16'))
scaler = torch.amp.GradScaler("cuda", enabled=(args.dtype == 'float16'))
# 创建 AdamW 优化器
optimizer = optim.AdamW(model.parameters(), lr=args.learning_rate)
# ========== 6. 从ckp恢复状态 ==========
# ========== 6.从检查点恢复训练状态 ==========
start_epoch, start_step = 0, 0
if ckp_data:
# 加载模型权重
model.load_state_dict(ckp_data['model'])
# 加载优化器状态
optimizer.load_state_dict(ckp_data['optimizer'])
# 加载梯度缩放器状态
scaler.load_state_dict(ckp_data['scaler'])
# 恢复训练轮次和步数
start_epoch = ckp_data['epoch']
start_step = ckp_data.get('step', 0)
# ========== 7. DDP包模型 ==========
# ========== 7.DDP 包装模型 ==========
# 如果使用分布式训练, 用 DistributedDataParallel 包装模型
if dist.is_initialized():
# 设置不需要同步的参数 (RoPE 位置编码的频率参数)
model._ddp_params_and_buffers_to_ignore = {"freqs_cos", "freqs_sin"}
model = DistributedDataParallel(model, device_ids=[local_rank])
# ========== 8. 开始训练 ==========
# ========== 8.开始训练 ==========
for epoch in range(start_epoch, args.epochs):
# 分布式采样器设置 epoch, 确保每个 epoch 的数据打乱顺序不同
train_sampler and train_sampler.set_epoch(epoch)
# 设置当前 epoch 的随机种子, 并打乱数据顺序
setup_seed(42 + epoch); indices = torch.randperm(len(train_ds)).tolist()
# 如果是恢复训练的第一个 epoch, 计算需要跳过的步数
skip = start_step if (epoch == start_epoch and start_step > 0) else 0
# 创建支持跳过批次的采样器
batch_sampler = SkipBatchSampler(train_sampler or indices, args.batch_size, skip)
# 创建数据加载器
loader = DataLoader(train_ds, batch_sampler=batch_sampler, num_workers=args.num_workers, pin_memory=True)
if skip > 0:
Logger(f'Epoch [{epoch + 1}/{args.epochs}]: 跳过前{start_step}个step从step {start_step + 1}开始')
# 执行训练 epoch
if skip > 0:
Logger(f'Epoch [{epoch + 1}/{args.epochs}]: 跳过前{start_step}个step, 从step {start_step + 1}开始')
train_epoch(epoch, loader, len(loader) + skip, tokenizer, lm_config, start_step, wandb)
else:
train_epoch(epoch, loader, len(loader), tokenizer, lm_config, 0, wandb)
# ========== 9. 清理分布进程 ==========
# ========== 9.清理分布式进程 ==========
# 训练完成后销毁进程组, 释放资源
if dist.is_initialized(): dist.destroy_process_group()

View File

@ -1,13 +1,14 @@
import os
import sys
# 设置包名, 用于相对导入
__package__ = "trainer"
# 将项目根目录添加到系统路径, 确保能够导入同级目录的模块
sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), '..')))
import argparse
import re
import gc
import warnings
import torch
import torch.distributed as dist
from transformers import AutoTokenizer
@ -21,36 +22,84 @@ from model.model_minimind import MiniMindConfig, MiniMindForCausalLM
from dataset.lm_dataset import RLAIFDataset
from trainer.trainer_utils import Logger, is_main_process, lm_checkpoint, init_distributed_mode, setup_seed, SkipBatchSampler, init_model
# 忽略警告信息, 保持输出整洁
import warnings
warnings.filterwarnings('ignore')
class AutoAdaptiveValueTracker:
"""SPO自适应价值追踪器"""
"""
SPO 自适应价值追踪器
使用 Beta 分布来估计奖励基线, 支持基于 KL 散度的自适应 rho 更新
"""
def __init__(self, rho_mode='kl', rho_const=0.9, D_half=0.06, clip_lower=0.5, clip_upper=0.96):
# rho 计算模式: 'kl' 表示基于 KL 散度, 'constant' 表示固定值
self.rho_mode = rho_mode
# 固定 rho 模式下的常数值
self.rho_const = rho_const
# KL 散度半衰期参数, 用于控制 rho 随 KL 变化的敏感度
self.D_half = D_half
# rho 值的裁剪上下限, 防止更新过快或过慢
self.clip_lower = clip_lower
self.clip_upper = clip_upper
# 初始化 Beta 分布的参数 alpha 和 beta
N_init = 1.0 / (1.0 - self.clip_lower)
self.alpha = 0.5 * N_init
self.beta = 0.5 * N_init
# 保存上一轮迭代的平均对数概率, 用于计算 KL 散度
self.old_mean_logprob = None
def get_baselines(self, batch_size):
"""
计算当前批次样本的基线值
Args:
- batch_size: 批次大小
Returns:
- Tensor: 形状为 (batch_size,) 的基线值张量
"""
# Beta 分布的期望值作为基线
baseline = self.alpha / (self.alpha + self.beta)
return torch.full((batch_size,), baseline, dtype=torch.float32)
def compute_rho(self, cur_mean_logprob):
"""
计算自适应 rho
Args:
- cur_mean_logprob: 当前平均对数概率
Returns:
- float: rho , 介于 clip_lower clip_upper 之间
"""
# 固定模式下返回常量
if self.rho_mode == 'constant':
return self.rho_const
# 首次迭代时返回默认值
if self.old_mean_logprob is None:
return self.rho_const
# 计算当前与上一轮的对数概率差值 (KL 散度的近似)
kl = abs(self.old_mean_logprob - cur_mean_logprob)
# 使用指数衰减公式计算 rho
rho = 2 ** (-kl / self.D_half)
# 裁剪到合理范围内
return max(min(rho, self.clip_upper), self.clip_lower)
def update(self, rewards, cur_logprobs=None, response_masks=None):
"""
根据奖励值更新 Beta 分布参数
Args:
- rewards: 奖励值张量, 形状为 (batch_size,)
- cur_logprobs: 当前策略的对数概率 (可选)
- response_masks: 响应序列的掩码 (可选)
Returns:
- float: 本次更新使用的 rho
"""
# 如果提供了对数概率和掩码, 计算平均对数概率并更新 rho
if cur_logprobs is not None and response_masks is not None:
mean_logprob = ((cur_logprobs * response_masks).sum() / response_masks.sum()).item()
rho = self.compute_rho(mean_logprob)
@ -58,22 +107,50 @@ class AutoAdaptiveValueTracker:
else:
rho = self.rho_const
# 将奖励值归一化到 [0, 1] 区间, 与 Beta 分布匹配
scale = 3.0
normalized_rewards = (rewards + scale) / (2 * scale)
avg_normalized_reward = normalized_rewards.mean().item()
# 使用 rho 进行指数移动平均更新 Beta 分布参数
self.alpha = rho * self.alpha + avg_normalized_reward
self.beta = rho * self.beta + (1 - avg_normalized_reward)
return rho
def calculate_rewards(prompts, responses, reward_model, reward_tokenizer):
"""整合所有奖励函数计算总奖励"""
"""
整合所有奖励函数计算总奖励
Args:
- prompts: 输入提示列表, 字符串列表
- responses: 模型生成的响应列表, 字符串列表
- reward_model: 奖励模型, 用于评估响应质量
- reward_tokenizer: 奖励模型的分词器
Returns:
- Tensor: 每个响应的奖励值, 形状为 (batch_size,)
"""
def reasoning_model_reward(rewards):
"""
为推理模型计算格式奖励
检查响应是否符合 CoT (Chain of Thought) 格式要求:
<think>...</think><answer>...</answer>
Args:
- rewards: 基础奖励张量
Returns:
- Tensor: 添加了格式奖励后的奖励张量
"""
# 定义两种合法的格式模式 (带换行或不带换行)
pattern = r"^<think>\n.*?\n</think>\n<answer>\n.*?\n</answer>$"
pattern2 = r"^<think>\n.*?\n</think>\n\n<answer>\n.*?\n</answer>$"
# 检查每个响应是否匹配格式模式
matches_pattern = [re.match(pattern, response, re.S) for response in responses]
matches_pattern2 = [re.match(pattern2, response, re.S) for response in responses]
# 格式完全正确时奖励 0.5 分
format_rewards = []
for match_pattern, match_pattern2 in zip(matches_pattern, matches_pattern2):
if match_pattern or match_pattern2:
@ -83,34 +160,44 @@ def calculate_rewards(prompts, responses, reward_model, reward_tokenizer):
rewards += torch.tensor(format_rewards, device=args.device)
def mark_num(text):
"""计算响应中特殊标记的数量奖励"""
reward = 0
# 每个特殊标记正确出现一次奖励 0.25 分, 最多 1.0 分
if text.count("<think>") == 1: reward += 0.25
if text.count("</think>") == 1: reward += 0.25
if text.count("<answer>") == 1: reward += 0.25
if text.count("</answer>") == 1: reward += 0.25
return reward
# 统计所有响应的标记奖励
mark_rewards = [mark_num(response) for response in responses]
rewards += torch.tensor(mark_rewards, device=args.device)
return rewards
# 初始化奖励值为零
rewards = torch.zeros(len(responses), device=args.device)
# 如果启用了推理模式, 先计算格式奖励
if args.reasoning == 1:
rewards = reasoning_model_reward(rewards)
# 使用奖励模型评估响应质量 (不计算梯度)
with torch.no_grad():
reward_model_scores = []
scale = 3.0
for i, (prompt, response) in enumerate(zip(prompts, responses)):
# 从 prompt 中提取对话历史
pattern = r"<\|im_start\|>(system|user|assistant)\s+(.*?)<\|im_end\|>"
matches = re.findall(pattern, prompt, re.DOTALL)
messages = [{"role": role, "content": content.strip()} for role, content in matches]
# 将当前响应添加到对话历史中
tmp_chat = messages + [{"role": "assistant", "content": response}]
# 获取奖励模型评分并裁剪到合理范围
score = reward_model.get_score(reward_tokenizer, tmp_chat)
score = max(min(score, scale), -scale)
# 如果是推理模型, 额外评估 <answer> 部分的内容质量
if args.reasoning == 1:
answer_match = re.search(r'<answer>(.*?)</answer>', response, re.DOTALL)
if answer_match:
@ -118,10 +205,12 @@ def calculate_rewards(prompts, responses, reward_model, reward_tokenizer):
tmp_chat = messages + [{"role": "assistant", "content": answer_content}]
answer_score = reward_model.get_score(reward_tokenizer, tmp_chat)
answer_score = max(min(answer_score, scale), -scale)
# 综合评分: 整体响应占 40%, 答案内容占 60%
score = score * 0.4 + answer_score * 0.6
reward_model_scores.append(score)
# 将奖励模型评分加入总奖励
reward_model_scores = torch.tensor(reward_model_scores, device=args.device)
rewards += reward_model_scores
@ -129,14 +218,32 @@ def calculate_rewards(prompts, responses, reward_model, reward_tokenizer):
def spo_train_epoch(epoch, loader, iters, ref_model, reward_model, reward_tokenizer, value_tracker, start_step=0, wandb=None):
"""
SPO 训练的单 epoch 循环
Args:
- epoch: 当前训练轮次
- loader: 数据加载器
- iters: 当前 epoch 的总步数
- ref_model: 参考模型 (Policy 模型的旧版本)
- reward_model: 奖励模型
- reward_tokenizer: 奖励模型的分词器
- value_tracker: 价值追踪器, 用于计算基线和更新
- start_step: 起始步数 (用于断点续训)
- wandb: Weights & Biases 日志对象 (可选)
"""
for step, batch in enumerate(loader, start=start_step + 1):
# 从批次中获取提示文本
prompts = batch['prompt'] # list[str], length B
# 对提示进行分词和填充, 左填充以保留序列末尾的重要信息
prompt_inputs = tokenizer(prompts, return_tensors="pt", padding=True, return_token_type_ids=False,
padding_side="left", add_special_tokens=False).to(args.device) # input_ids: [B, P], attention_mask: [B, P]
# 裁剪提示长度, 只保留最后 max_seq_len 个 token
if args.max_seq_len:
prompt_inputs["input_ids"] = prompt_inputs["input_ids"][:, -args.max_seq_len:]
prompt_inputs["attention_mask"] = prompt_inputs["attention_mask"][:, -args.max_seq_len:]
# 使用 Policy 模型生成响应 (不计算梯度)
with torch.no_grad():
# DDP 模型需要使用 .module 访问 generate 方法
model_for_gen = model.module if isinstance(model, DistributedDataParallel) else model
@ -144,53 +251,82 @@ def spo_train_epoch(epoch, loader, iters, ref_model, reward_model, reward_tokeni
**prompt_inputs, max_new_tokens=args.max_gen_len, do_sample=True, temperature=0.8,
num_return_sequences=1, pad_token_id=tokenizer.pad_token_id) # [B, P+R]
# 提取生成的响应部分 (去掉提示部分)
completion_ids = outputs[:, prompt_inputs["input_ids"].size(1):] # [B, R]
def get_per_token_logps(mdl, input_ids, n_keep):
"""
计算每个 token 的对数概率
Args:
- mdl: 模型对象
- input_ids: 输入 token ID 张量
- n_keep: 需要计算对数概率的最后 n_keep token
Returns:
- Tensor: 形状为 (batch_size, n_keep) 的对数概率张量
"""
# 分离输入以避免梯度计算 (仅在推理模式下)
input_ids = input_ids.detach().clone() if input_ids.is_inference() else input_ids
# 前向传播获取 logits, 去掉最后一个位置 (因为没有下一个 token 作为标签)
logits = mdl(input_ids, logits_to_keep=n_keep + 1).logits[:, :-1, :]
per_token_logps = []
# 对每个样本, 计算其在序列中每个 token 的对数概率
for logits_row, ids_row in zip(logits, input_ids[:, -n_keep:]):
ids_row = ids_row.detach().clone() if ids_row.is_inference() else ids_row
# 使用 gather 获取对应 token 的对数概率
per_token_logps.append(torch.gather(logits_row.log_softmax(dim=-1), 1, ids_row.unsqueeze(1)).squeeze(1))
return torch.stack(per_token_logps)
# 计算 Policy 模型的对数概率和 MoE 辅助损失 (如果有)
with autocast_ctx:
per_token_logps = get_per_token_logps(model, outputs, completion_ids.size(1)) # [B, R]
res = model(outputs) if lm_config.use_moe else None
aux_loss = res.aux_loss if res is not None else torch.tensor(0.0, device=args.device)
# 计算参考模型的对数概率 (不计算梯度)
with torch.no_grad():
ref_per_token_logps = get_per_token_logps(ref_model, outputs, completion_ids.size(1)) # [B, R]
# 将 token ID 解码为文本响应
completions = tokenizer.batch_decode(completion_ids, skip_special_tokens=True) # list[str], length B
# 计算奖励值
rewards = calculate_rewards(prompts, completions, reward_model, reward_tokenizer).to(args.device) # [B]
# 获取基线值 (用于优势估计)
baselines = value_tracker.get_baselines(len(prompts)).to(args.device) # [B]
scale = 3.0
# Un-normalize baselines to be in the same scale as raw rewards [-3, 3]
# 将归一化的基线值反归一化到原始奖励范围 [-3, 3]
unnormalized_baselines = baselines * (2 * scale) - scale # [B]
# 计算优势函数: 奖励 - 基线
advantages = rewards - unnormalized_baselines # [B]
# 直接使用 baseline 提供的优势估计,只做裁剪防止梯度爆炸。不再做 batch 内归一化,因为 baseline 已经提供了跨 batch 的稳定基线
# 直接使用 baseline 提供的优势估计, 只做裁剪防止梯度爆炸
advantages = advantages.clamp(-5.0, 5.0)
# 构建 completion mask, 标记有效响应序列的位置
is_eos = completion_ids == tokenizer.eos_token_id # [B, R]
eos_idx = torch.full((is_eos.size(0),), is_eos.size(1), dtype=torch.long, device=args.device) # [B]
eos_idx[is_eos.any(dim=1)] = is_eos.int().argmax(dim=1)[is_eos.any(dim=1)]
completion_mask = (torch.arange(is_eos.size(1), device=args.device).expand(is_eos.size(0), -1) <= eos_idx.unsqueeze(1)).int() # [B, R]
# 计算 KL 散度 (Policy 与参考模型之间的差异)
kl_div = ref_per_token_logps - per_token_logps # [B, R]
per_token_kl = torch.exp(kl_div) - kl_div - 1 # [B, R]
per_token_kl = torch.exp(kl_div) - kl_div - 1 # [B, R], 使用反向 KL 估计
# 计算每个 token 的损失: 策略梯度 + KL 惩罚
per_token_loss = -per_token_logps * advantages.unsqueeze(1) + args.beta * per_token_kl # [B, R]
# 对响应长度取平均, 再对 batch 取平均得到策略损失
policy_loss = ((per_token_loss * completion_mask).sum(dim=1) / completion_mask.sum(dim=1)).mean()
# 总损失包含策略损失和辅助损失 (MoE 负载均衡损失), 除以累积步数用于梯度累积
loss = (policy_loss + aux_loss) / args.accumulation_steps # scalar
loss.backward()
# 更新价值追踪器, 获取当前的 rho 值
response_masks = completion_mask.float() # [B, R]
rho = value_tracker.update(rewards, per_token_logps.detach(), response_masks)
# 梯度累积: 每 accumulation_steps 步更新一次模型参数
if (step + 1) % args.accumulation_steps == 0:
if args.grad_clip > 0:
torch.nn.utils.clip_grad_norm_(model.parameters(), args.grad_clip)
@ -198,6 +334,7 @@ def spo_train_epoch(epoch, loader, iters, ref_model, reward_model, reward_tokeni
scheduler.step()
optimizer.zero_grad()
# 打印日志并记录到 wandb
if step % args.log_interval == 0 or step == iters:
policy_loss_val = loss.item() * args.accumulation_steps
current_aux_loss = aux_loss.item()
@ -224,19 +361,23 @@ def spo_train_epoch(epoch, loader, iters, ref_model, reward_model, reward_tokeni
"learning_rate": current_lr
})
# 定期保存模型检查点
if (step % args.save_interval == 0 or step == iters - 1) and is_main_process():
model.eval()
moe_suffix = '_moe' if lm_config.use_moe else ''
ckp = f'{args.save_dir}/{args.save_weight}_{lm_config.hidden_size}{moe_suffix}.pth'
raw_model = model.module if isinstance(model, DistributedDataParallel) else model
raw_model = getattr(raw_model, '_orig_mod', raw_model)
# 获取模型状态字典并转为半精度, 保存到 CPU
state_dict = raw_model.state_dict()
torch.save({k: v.half().cpu() for k, v in state_dict.items()}, ckp)
# 同时保存完整检查点 (包含优化器状态等)
lm_checkpoint(lm_config, weight=args.save_weight, model=model, optimizer=optimizer,
epoch=epoch, step=step, wandb=wandb, save_dir='../checkpoints', scheduler=scheduler)
model.train()
del state_dict
# 清理本轮迭代的中间变量以释放内存
del prompt_inputs, outputs, completion_ids, per_token_logps, ref_per_token_logps
del completions, rewards, advantages, completion_mask, baselines, response_masks
@ -257,36 +398,43 @@ if __name__ == "__main__":
parser.add_argument("--save_interval", type=int, default=10, help="模型保存间隔")
parser.add_argument('--hidden_size', default=512, type=int, help="隐藏层维度")
parser.add_argument('--num_hidden_layers', default=8, type=int, help="隐藏层数量")
parser.add_argument('--use_moe', default=0, type=int, choices=[0, 1], help="是否使用MoE架构0=否1=是)")
parser.add_argument('--use_moe', default=0, type=int, choices=[0, 1], help="是否使用MoE架构 (0=否, 1=是)")
parser.add_argument('--max_seq_len', default=66, type=int, help="Prompt最大长度")
parser.add_argument("--max_gen_len", type=int, default=1536, help="生成的最大长度")
parser.add_argument("--data_path", type=str, default="../dataset/rlaif-mini.jsonl", help="RLAIF数据路径")
parser.add_argument("--beta", type=float, default=0.02, help="KL惩罚系数")
parser.add_argument("--reasoning", type=int, default=1, choices=[0, 1], help='推理模型类型0=普通模型1=推理模型)')
parser.add_argument("--reasoning", type=int, default=1, choices=[0, 1], help='推理模型类型 (0=普通模型, 1=推理模型)')
parser.add_argument("--reward_model_path", type=str, default="../../internlm2-1_8b-reward", help="Reward模型路径")
parser.add_argument('--from_resume', default=0, type=int, choices=[0, 1], help="是否自动检测&续训0=否1=是)")
parser.add_argument('--from_resume', default=0, type=int, choices=[0, 1], help="是否自动检测&续训(0=否, 1=是)")
parser.add_argument("--use_wandb", action="store_true", help="是否使用wandb")
parser.add_argument("--wandb_project", type=str, default="MiniMind-SPO", help="wandb项目名")
parser.add_argument("--use_compile", default=0, type=int, choices=[0, 1], help="是否使用torch.compile加速0=否1=是)")
parser.add_argument("--use_compile", default=0, type=int, choices=[0, 1], help="是否使用torch.compile加速 (0=否, 1=是)")
args = parser.parse_args()
# ========== 1. 初始化环境和随机种子 ==========
# ========== 1.初始化环境和随机种子 ==========
local_rank = init_distributed_mode()
if dist.is_initialized(): args.device = f"cuda:{local_rank}"
setup_seed(42 + (dist.get_rank() if dist.is_initialized() else 0))
# ========== 2. 配置目录、模型参数、检查ckp ==========
# ========== 2.配置目录、模型参数、检查 ckp ==========
os.makedirs(args.save_dir, exist_ok=True)
lm_config = MiniMindConfig(hidden_size=args.hidden_size, num_hidden_layers=args.num_hidden_layers,
max_seq_len=args.max_seq_len + args.max_gen_len, use_moe=bool(args.use_moe))
lm_config = MiniMindConfig(
hidden_size=args.hidden_size,
num_hidden_layers=args.num_hidden_layers,
max_seq_len=args.max_seq_len + args.max_gen_len,
use_moe=bool(args.use_moe)
)
# 如果启用续训, 检查是否存在检查点
ckp_data = lm_checkpoint(lm_config, weight=args.save_weight, save_dir='../checkpoints') if args.from_resume==1 else None
# ========== 3. 设置混合精度 ==========
# ========== 3.设置混合精度 ==========
device_type = "cuda" if "cuda" in args.device else "cpu"
dtype = torch.bfloat16 if args.dtype == "bfloat16" else torch.float16
autocast_ctx = nullcontext() if device_type == "cpu" else torch.cuda.amp.autocast(dtype=dtype)
# https://docs.pytorch.ac.cn/docs/stable/amp
# 新版废弃: autocast_ctx = nullcontext() if device_type == "cpu" else torch.cuda.amp.autocast(dtype=dtype)
autocast_ctx = nullcontext() if device_type == "cpu" else torch.amp.autocast(device_type=device_type, dtype=dtype)
# ========== 4. 配wandb ==========
# ========== 4.配置 wandb ==========
wandb = None
if args.use_wandb and is_main_process():
import swanlab as wandb
@ -295,35 +443,37 @@ if __name__ == "__main__":
wandb_run_name = f"MiniMind-SPO-Epoch-{args.epochs}-BS-{args.batch_size}-LR-{args.learning_rate}"
wandb.init(project=args.wandb_project, name=wandb_run_name, id=wandb_id, resume=resume)
# ========== 5. 初始化模型Policy, Ref, RewardValue Tracker、数据 ==========
# ========== 5.初始化模型(Policy, Ref, Reward)和 Value Tracker、数据 ==========
base_weight = "reason" if args.reasoning == 1 else "full_sft"
# Policy模型
# Policy 模型 (要训练的模型)
model, tokenizer = init_model(lm_config, base_weight, device=args.device)
if args.use_compile == 1:
model = torch.compile(model)
Logger('torch.compile enabled')
# Reference模型
# Reference 模型 (固定参数, 用于计算 KL 散度)
ref_model, _ = init_model(lm_config, base_weight, device=args.device)
ref_model = ref_model.eval().requires_grad_(False)
# Reward模型
# Reward 模型 (固定参数, 用于评估响应质量)
reward_model = AutoModel.from_pretrained(
args.reward_model_path, torch_dtype=torch.float16, trust_remote_code=True
)
reward_model = reward_model.to(args.device).eval().requires_grad_(False)
reward_tokenizer = AutoTokenizer.from_pretrained(args.reward_model_path, trust_remote_code=True)
# Value Tracker
# Value Tracker (自适应基线估计器)
value_tracker = AutoAdaptiveValueTracker(rho_mode='kl', rho_const=0.9, D_half=0.06, clip_lower=0.5, clip_upper=0.96)
# 加载 RLAIF 数据集
train_ds = RLAIFDataset(args.data_path, tokenizer, max_length=lm_config.max_seq_len)
train_sampler = DistributedSampler(train_ds) if dist.is_initialized() else None
optimizer = optim.AdamW(model.parameters(), lr=args.learning_rate)
# 计算训练步数和学习率调度器
loader_for_count = DataLoader(train_ds, batch_size=args.batch_size, sampler=train_sampler)
iters = len(loader_for_count)
total_optimizer_steps = (iters // args.accumulation_steps) * args.epochs
scheduler = CosineAnnealingLR(optimizer, T_max=total_optimizer_steps, eta_min=args.learning_rate / 10)
# ========== 6. 从ckp恢复状态 ==========
# ========== 6.从检查点恢复状态 ==========
start_epoch, start_step = 0, 0
if ckp_data:
model.load_state_dict(ckp_data['model'])
@ -332,15 +482,19 @@ if __name__ == "__main__":
start_epoch = ckp_data['epoch']
start_step = ckp_data.get('step', 0)
# ========== 7. DDP包模型 ==========
# ========== 7.DDP 模型 ==========
if dist.is_initialized():
# 忽略位置编码相关的缓冲区, 避免 DDP 同步
model._ddp_params_and_buffers_to_ignore = {"freqs_cos", "freqs_sin"}
model = DistributedDataParallel(model, device_ids=[local_rank])
# ========== 8. 开始训练 ==========
# ========== 8.开始训练 ==========
for epoch in range(start_epoch, args.epochs):
# 分布式采样器设置当前 epoch, 确保不同进程使用不同的数据
train_sampler and train_sampler.set_epoch(epoch)
# 为每个 epoch 设置不同的随机种子, 保证数据顺序不同
setup_seed(42 + epoch); indices = torch.randperm(len(train_ds)).tolist()
# 如果是续训的第一轮, 跳过已经训练过的步数
skip = start_step if (epoch == start_epoch and start_step > 0) else 0
batch_sampler = SkipBatchSampler(train_sampler or indices, args.batch_size, skip)
loader = DataLoader(train_ds, batch_sampler=batch_sampler, num_workers=args.num_workers, pin_memory=True)
@ -350,5 +504,5 @@ if __name__ == "__main__":
else:
spo_train_epoch(epoch, loader, len(loader), ref_model, reward_model, reward_tokenizer, value_tracker, 0, wandb)
# ========== 9. 清理分布进程 ==========
# ========== 9.清理分布进程 ==========
if dist.is_initialized(): dist.destroy_process_group()

View File

@ -1,40 +1,88 @@
# 注不建议再重复训练tokenizer“词典”MiniMind已自带此脚本仅供学习和参考。基于不同词典训练的模型将导致输出完全不统一降低社区的模型复用性
# Note: It is not recommended to re-train the tokenizer. MiniMind already includes one. This script is for learning and reference only. Training models with different tokenizers will lead to inconsistent outputs and reduce model reusability in the community.
# 注: 不建议再重复训练 tokenizer("词典"), MiniMind 已自带, 此脚本仅供学习和参考
# 基于不同词典训练的模型将导致输出完全不统一, 降低社区的模型复用性
# Note: It is not recommended to re-train the tokenizer. MiniMind already includes one. This script is for learning and reference only.
# Training models with different tokenizers will lead to inconsistent outputs and reduce model reusability in the community.
import os
import json
from tokenizers import decoders, models, pre_tokenizers, trainers, Tokenizer
# 训练数据文件路径
DATA_PATH = '../dataset/pretrain_hq.jsonl'
# 训练好的 tokenizer 保存目录
TOKENIZER_DIR = '../model_learn_tokenizer/'
# 词表大小, 控制 tokenizer 的词汇量
VOCAB_SIZE = 6400
def get_texts(data_path):
"""
JSONL 文件中逐行读取文本数据
这是一个生成器函数, 用于逐行读取数据文件并提取文本内容
最多读取前 10000 行用于实验性训练
Args:
- data_path: JSONL 格式的数据文件路径
Yields:
- str: 每行数据中 'text' 字段的内容
"""
with open(data_path, 'r', encoding='utf-8') as f:
for i, line in enumerate(f):
if i >= 10000: break # 实验性可只用前10000行测试
# 实验性, 可只用前 10000 行测试
if i >= 10000:
break
data = json.loads(line)
yield data['text']
def train_tokenizer(data_path, tokenizer_dir, vocab_size):
"""
训练 BPE(Byte Pair Encoding) tokenizer
使用 Hugging Face tokenizers 库训练自定义的 BPE tokenizer
包括配置预分词器训练器保存词表和配置文件等步骤
Args:
- data_path: 训练数据文件路径
- tokenizer_dir: tokenizer 保存目录
- vocab_size: 目标词表大小
"""
# 创建 BPE 模型的 tokenizer 实例
tokenizer = Tokenizer(models.BPE())
# 设置 ByteLevel 预分词器, 不添加前缀空格
tokenizer.pre_tokenizer = pre_tokenizers.ByteLevel(add_prefix_space=False)
# 配置 BPE 训练器
trainer = trainers.BpeTrainer(
vocab_size=vocab_size,
# 定义特殊 token: 文本结束符、指令开始符、指令结束符
special_tokens=["<|endoftext|>", "<|im_start|>", "<|im_end|>"],
show_progress=True,
# 使用 ByteLevel 预分词器的初始字母表
initial_alphabet=pre_tokenizers.ByteLevel.alphabet()
)
# 获取训练文本数据
texts = get_texts(data_path)
# 使用迭代器训练 tokenizer
tokenizer.train_from_iterator(texts, trainer=trainer)
# 设置 ByteLevel 解码器
tokenizer.decoder = decoders.ByteLevel()
# 验证特殊 token 的 ID 是否符合预期
assert tokenizer.token_to_id("<|endoftext|>") == 0
assert tokenizer.token_to_id("<|im_start|>") == 1
assert tokenizer.token_to_id("<|im_end|>") == 2
# 创建保存目录 (如果不存在)
os.makedirs(tokenizer_dir, exist_ok=True)
# 保存完整的 tokenizer 配置
tokenizer.save(os.path.join(tokenizer_dir, "tokenizer.json"))
# 保存 BPE 模型文件 (vocab.json 和 merges.txt)
tokenizer.model.save(tokenizer_dir)
# 构建 Hugging Face 格式的 tokenizer 配置
config = {
"add_bos_token": False,
"add_eos_token": False,
@ -79,48 +127,68 @@ def train_tokenizer(data_path, tokenizer_dir, vocab_size):
"chat_template": "{%- if tools %}\n {{- '<|im_start|>system\\n' }}\n {%- if messages[0].role == 'system' %}\n {{- messages[0].content + '\\n\\n' }}\n {%- endif %}\n {{- \"# Tools\\n\\nYou may call one or more functions to assist with the user query.\\n\\nYou are provided with function signatures within <tools></tools> XML tags:\\n<tools>\" }}\n {%- for tool in tools %}\n {{- \"\\n\" }}\n {{- tool | tojson }}\n {%- endfor %}\n {{- \"\\n</tools>\\n\\nFor each function call, return a json object with function name and arguments within <tool_call></tool_call> XML tags:\\n<tool_call>\\n{\\\"name\\\": <function-name>, \\\"arguments\\\": <args-json-object>}\\n</tool_call><|im_end|>\\n\" }}\n{%- else %}\n {%- if messages[0]['role'] == 'system' -%}\n {{- '<|im_start|>system\\n' + messages[0]['content'] + '<|im_end|>\\n' }}\n {%- else -%}\n {{- '<|im_start|>system\\nYou are a helpful assistant<|im_end|>\\n' }}\n {%- endif %}\n{%- endif %}\n{%- set ns = namespace(multi_step_tool=true, last_query_index=messages|length - 1) %}\n{%- for message in messages[::-1] %}\n {%- set index = (messages|length - 1) - loop.index0 %}\n {%- if ns.multi_step_tool and message.role == \"user\" and message.content is string and not(message.content.startswith('<tool_response>') and message.content.endswith('</tool_response>')) %}\n {%- set ns.multi_step_tool = false %}\n {%- set ns.last_query_index = index %}\n {%- endif %}\n{%- endfor %}\n{%- for message in messages %}\n {%- if message.content is string %}\n {%- set content = message.content %}\n {%- else %}\n {%- set content = '' %}\n {%- endif %}\n {%- if (message.role == \"user\") or (message.role == \"system\" and not loop.first) %}\n {{- '<|im_start|>' + message.role + '\\n' + content + '<|im_end|>' + '\\n' }}\n {%- elif message.role == \"assistant\" %}\n {{- '<|im_start|>' + message.role + '\\n' + content }}\n {%- if message.tool_calls %}\n {%- for tool_call in message.tool_calls %}\n {%- if (loop.first and content) or (not loop.first) %}\n {{- '\\n' }}\n {%- endif %}\n {%- if tool_call.function %}\n {%- set tool_call = tool_call.function %}\n {%- endif %}\n {{- '<tool_call>\\n{\"name\": \"' }}\n {{- tool_call.name }}\n {{- '\", \"arguments\": ' }}\n {%- if tool_call.arguments is string %}\n {{- tool_call.arguments }}\n {%- else %}\n {{- tool_call.arguments | tojson }}\n {%- endif %}\n {{- '}\\n</tool_call>' }}\n {%- endfor %}\n {%- endif %}\n {{- '<|im_end|>\\n' }}\n {%- elif message.role == \"tool\" %}\n {%- if loop.first or (messages[loop.index0 - 1].role != \"tool\") %}\n {{- '<|im_start|>user' }}\n {%- endif %}\n {{- '\\n<tool_response>\\n' }}\n {{- content }}\n {{- '\\n</tool_response>' }}\n {%- if loop.last or (messages[loop.index0 + 1].role != \"tool\") %}\n {{- '<|im_end|>\\n' }}\n {%- endif %}\n {%- endif %}\n{%- endfor %}\n{%- if add_generation_prompt %}\n {{- '<|im_start|>assistant\\n' }}\n {%- if enable_thinking is defined and enable_thinking is false %}\n {{- '<think>\\n\\n</think>\\n\\n' }}\n {%- endif %}\n{%- endif %}"
}
# 保存 tokenizer 配置文件
with open(os.path.join(tokenizer_dir, "tokenizer_config.json"), "w", encoding="utf-8") as f:
json.dump(config, f, ensure_ascii=False, indent=4)
print("Tokenizer training completed.")
def eval_tokenizer(tokenizer_dir):
"""
评估和测试训练好的 tokenizer
加载训练好的 tokenizer, 使用对话模板生成 prompt,
并测试编码解码的一致性和流式解码效果
Args:
- tokenizer_dir: tokenizer 模型目录路径
"""
from transformers import AutoTokenizer
# 加载训练好的 tokenizer
tokenizer = AutoTokenizer.from_pretrained(tokenizer_dir)
# 构造测试对话消息
messages = [
{"role": "system", "content": "你是一个优秀的聊天机器人,总是给我正确的回应!"},
{"role": "user", "content": '你来自哪里?'},
{"role": "system", "content": "你是一个优秀的聊天机器人, 总是给我正确的回应!"},
{"role": "user", "content": '你来自哪里?'},
{"role": "assistant", "content": '我来自地球'}
]
# 应用对话模板生成 prompt
new_prompt = tokenizer.apply_chat_template(
messages,
tokenize=False
)
print('-'*100)
print('-' * 100)
print(new_prompt)
print('-'*100)
print('tokenizer词表长度', len(tokenizer))
print('-' * 100)
# 打印 tokenizer 词表长度
print('tokenizer 词表长度:', len(tokenizer))
# 对 prompt 进行编码
model_inputs = tokenizer(new_prompt)
print('encoder长度', len(model_inputs['input_ids']))
print('encoder 长度:', len(model_inputs['input_ids']))
# 解码并验证一致性
response = tokenizer.decode(model_inputs['input_ids'], skip_special_tokens=False)
print('decoder一致性:', response == new_prompt, "\n")
print('decoder 一致性:', response == new_prompt, "\n")
print('-'*100)
print('流式解码(字节缓冲)测试:')
print('-' * 100)
print('流式解码(字节缓冲)测试:')
input_ids = model_inputs['input_ids']
token_cache = []
for tid in input_ids:
token_cache.append(tid)
# 尝试解码当前缓存的 token
current_decode = tokenizer.decode(token_cache)
# 当解码结果有效且不包含 Unicode 替换字符时输出
if current_decode and '\ufffd' not in current_decode:
display_ids = token_cache[0] if len(token_cache) == 1 else token_cache
raw_tokens = [tokenizer.convert_ids_to_tokens(int(t)) for t in (token_cache if isinstance(token_cache, list) else [token_cache])]
print(f'Token ID: {str(display_ids):15} -> Raw: {str(raw_tokens):20} -> Decode Str: {current_decode}')
token_cache = []
if __name__ == '__main__':
train_tokenizer(DATA_PATH, TOKENIZER_DIR, VOCAB_SIZE)
eval_tokenizer(TOKENIZER_DIR)

View File

@ -1,10 +1,11 @@
"""
训练工具函数集合
"""
import os
import sys
# 设置包名
__package__ = "trainer"
# 将项目根目录添加到系统路径, 确保能够导入同级目录的模块
sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), '..')))
import random
import math
import numpy as np
@ -16,64 +17,194 @@ from transformers import AutoTokenizer
from model.model_minimind import MiniMindForCausalLM
def get_model_params(model, config):
"""
计算并打印模型参数量统计信息
Args:
- model: 待统计的 PyTorch 模型
- config: 模型配置对象, 包含 use_moe, n_routed_experts, num_experts, num_experts_per_tok, n_shared_experts 等属性
Returns:
- 无返回值, 会在主进程打印模型参数量信息
"""
# 计算模型总参数量, 单位为百万 (M)
total = sum(p.numel() for p in model.parameters()) / 1e6
# 从配置中获取 MoE 相关的专家数量配置
# n_routed: 路由专家总数, n_active: 每个 token 激活的专家数, n_shared: 共享专家数
n_routed = getattr(config, 'n_routed_experts', getattr(config, 'num_experts', 0))
n_active = getattr(config, 'num_experts_per_tok', 0)
n_shared = getattr(config, 'n_shared_experts', 0)
# 计算路由专家 (routed experts) 的参数量, 单位为百万 (M)
# 通过匹配参数名中包含 'mlp.experts.0.' 的参数来统计
expert = sum(p.numel() for n, p in model.named_parameters() if 'mlp.experts.0.' in n) / 1e6
# 计算共享专家 (shared experts) 的参数量, 单位为百万 (M)
# 通过匹配参数名中包含 'mlp.shared_experts.0.' 的参数来统计
shared_expert = sum(p.numel() for n, p in model.named_parameters() if 'mlp.shared_experts.0.' in n) / 1e6
# 计算基础参数量 (非专家部分的参数)
base = total - (expert * n_routed) - (shared_expert * n_shared)
# 计算激活的参数量 (基础参数 + 激活的专家参数)
active = base + (expert * n_active) + (shared_expert * n_shared)
if active < total: Logger(f'Model Params: {total:.2f}M-A{active:.2f}M')
else: Logger(f'Model Params: {total:.2f}M')
# 如果存在未被激活的专家参数, 打印总参数量和激活参数量, 否则只打印总参数量
if active < total:
Logger(f'Model Params: {total:.2f}M-A{active:.2f}M')
else:
Logger(f'Model Params: {total:.2f}M')
def is_main_process():
"""
判断当前进程是否为主进程
在分布式训练环境中, 只有 rank 0 的进程是主进程
如果分布式环境未初始化, 也认为当前是主进程
Returns:
- bool: 如果是主进程返回 True, 否则返回 False
"""
return not dist.is_initialized() or dist.get_rank() == 0
def Logger(content):
"""
线程安全的日志打印函数
只能在主进程 (rank 0) 打印日志, 避免分布式训练时日志重复输出
Args:
- content: 要打印的内容, 可以是字符串或任何可打印对象
"""
if is_main_process():
print(content)
def get_lr(current_step, total_steps, lr):
"""
计算当前学习率, 使用余弦退火策略 (Cosine Annealing)
学习率变化曲线: 0.55*lr 开始, 经过余弦变化, 最终回到 0.1*lr
公式: lr * (0.1 + 0.45 * (1 + cos(pi * current_step / total_steps)))
Args:
- current_step: 当前训练步数 ( 0 开始)
- total_steps: 总训练步数
- lr: 基础学习率
Returns:
- float: 当前步的学习率值
"""
return lr*(0.1 + 0.45*(1 + math.cos(math.pi * current_step / total_steps)))
def init_distributed_mode():
if int(os.environ.get("RANK", -1)) == -1:
return 0 # 非DDP模式
"""
初始化分布式训练模式
检查环境变量 RANK 是否存在来判断是否使用分布式训练
如果未使用分布式训练, 返回 0 表示非 DDP 模式
如果使用分布式训练, 初始化 NCCL 后端, 设置本地 GPU 设备
Returns:
- int: 本地 rank 编号, DDP 模式下返回 0
"""
# 检查是否需要初始化分布式训练
# 如果环境变量 RANK 不存在或为 -1, 则表示不使用分布式训练
if int(os.environ.get("RANK", -1)) == -1:
return 0 # 非 DDP 模式
# 初始化分布式训练进程组, 使用 NCCL 后端 (适用于 GPU 通信)
dist.init_process_group(backend="nccl")
# 从环境变量获取本地 rank, 并设置当前进程使用的 GPU 设备
local_rank = int(os.environ["LOCAL_RANK"])
torch.cuda.set_device(local_rank)
return local_rank
def setup_seed(seed: int):
"""
设置所有随机数种子, 确保实验可复现
- 设置 Python randomnumpytorch 的随机种子
- 以及 CUDA 相关 的随机种子和确定性设置
Args:
- seed: 随机数种子值, 建议使用整数
"""
# 设置 Python 内置随机数模块的种子
random.seed(seed)
# 设置 NumPy 的随机数种子
np.random.seed(seed)
# 设置 PyTorch CPU 计算的随机种子
torch.manual_seed(seed)
# 设置 PyTorch 单个 GPU 的随机种子
torch.cuda.manual_seed(seed)
# 设置 PyTorch 所有 GPU 的随机种子 (多卡训练时需要)
torch.cuda.manual_seed_all(seed)
# 设置 cuDNN 为确定性模式, 禁用优化以确保可复现
# 这会降低训练速度, 但能确保结果可复现
torch.backends.cudnn.deterministic = True
# 禁用 cuDNN 的自动调优功能, 避免因算法选择导致的不可复现
torch.backends.cudnn.benchmark = False
def lm_checkpoint(lm_config, weight='full_sft', model=None, optimizer=None, epoch=0, step=0, wandb=None, save_dir='../checkpoints', **kwargs):
"""
模型检查点保存与加载函数
支持两种模式:
1. 保存模式: model 参数不为 None , 保存模型权重优化器状态训练进度等信息
2. 加载模式: model 参数为 None , 从检查点文件恢复训练状态
Args:
- lm_config: 模型配置对象, 包含 use_moe, hidden_size 等属性
- weight: 检查点名称前缀, 默认为 'full_sft'
- model: 待保存的模型对象, None 时表示加载模式
- optimizer: 优化器对象, 用于保存优化器状态
- epoch: 当前训练轮次
- step: 当前训练步数
- wandb: Weights & Biases 日志对象, 用于获取运行 ID
- save_dir: 检查点保存目录
- **kwargs: 额外的可保存对象 ( lr_scheduler ), 需要具备 state_dict 方法
Returns:
- dict None: 加载模式下返回检查点数据字典, 保存模式下返回 None
"""
# 确保保存目录存在
os.makedirs(save_dir, exist_ok=True)
# 构建检查点文件名
# 根据是否使用 MoE 模型添加不同的后缀
moe_path = '_moe' if lm_config.use_moe else ''
# 完整模型权重文件路径
ckp_path = f'{save_dir}/{weight}_{lm_config.hidden_size}{moe_path}.pth'
# 恢复训练所需的完整检查点文件路径 (包含优化器状态等)
resume_path = f'{save_dir}/{weight}_{lm_config.hidden_size}{moe_path}_resume.pth'
# ===== 保存模式 =====
if model is not None:
# 获取原始模型 (如果是 DDP 包装的模型, 需要解包)
raw_model = model.module if isinstance(model, DistributedDataParallel) else model
# 处理 PEFT/Lora 等场景下的模型 (使用 _orig_mod 属性获取原始模型)
raw_model = getattr(raw_model, '_orig_mod', raw_model)
# 获取模型状态字典, 并将权重转换为半精度 (FP16) 后转移到 CPU
state_dict = raw_model.state_dict()
state_dict = {k: v.half().cpu() for k, v in state_dict.items()}
# 使用临时文件避免保存过程中断导致文件损坏
ckp_tmp = ckp_path + '.tmp'
torch.save(state_dict, ckp_tmp)
os.replace(ckp_tmp, ckp_path)
# 获取 Weights & Biases 的运行 ID
wandb_id = None
if wandb:
if hasattr(wandb, 'get_run'):
@ -82,16 +213,20 @@ def lm_checkpoint(lm_config, weight='full_sft', model=None, optimizer=None, epoc
else:
wandb_id = getattr(wandb, 'id', None)
# 构建恢复训练所需的完整数据字典
resume_data = {
'model': state_dict,
'optimizer': optimizer.state_dict(),
'epoch': epoch,
'step': step,
'world_size': dist.get_world_size() if dist.is_initialized() else 1,
'wandb_id': wandb_id
'model': state_dict, # 模型权重
'optimizer': optimizer.state_dict(), # 优化器状态
'epoch': epoch, # 当前训练轮次
'step': step, # 当前训练步数
'world_size': dist.get_world_size() if dist.is_initialized() else 1, # 分布式训练 world size
'wandb_id': wandb_id # W&B 运行 ID
}
# 处理额外的可保存对象 (如学习率调度器等)
for key, value in kwargs.items():
if value is not None:
# 如果对象是模型 (可能是 DDP 包装的), 先解包再获取 state_dict
if hasattr(value, 'state_dict'):
raw_value = value.module if isinstance(value, DistributedDataParallel) else value
raw_value = getattr(raw_value, '_orig_mod', raw_value)
@ -99,59 +234,142 @@ def lm_checkpoint(lm_config, weight='full_sft', model=None, optimizer=None, epoc
else:
resume_data[key] = value
# 保存完整检查点到临时文件, 然后原子替换
resume_tmp = resume_path + '.tmp'
torch.save(resume_data, resume_tmp)
os.replace(resume_tmp, resume_path)
# 清理内存
del state_dict, resume_data
torch.cuda.empty_cache()
else: # 加载模式
# ===== 加载模式 =====
else:
# 如果检查点文件存在, 加载恢复数据
if os.path.exists(resume_path):
ckp_data = torch.load(resume_path, map_location='cpu')
# 处理 GPU 数量变化的情况, 自动调整 step 值
saved_ws = ckp_data.get('world_size', 1)
current_ws = dist.get_world_size() if dist.is_initialized() else 1
if saved_ws != current_ws:
# 当 GPU 数量变化时, 按比例调整 step
ckp_data['step'] = ckp_data['step'] * saved_ws // current_ws
Logger(f'GPU数量变化({saved_ws}{current_ws})step已自动转换为{ckp_data["step"]}')
Logger(f'GPU 数量变化 ({saved_ws}{current_ws}), step 已自动转换为 {ckp_data["step"]}')
return ckp_data
# 如果检查点文件不存在, 返回 None
return None
def init_model(lm_config, from_weight='pretrain', tokenizer_path='../model', save_dir='../out', device='cuda'):
"""
初始化模型和分词器
加载 Hugging Face 格式的分词器, 创建 MiniMind 模型实例
可选择从预训练权重加载模型参数
Args:
- lm_config: 模型配置对象
- from_weight: 预训练权重名称前缀, 默认为 'pretrain', 设置为 'none' 表示不加载权重
- tokenizer_path: 分词器模型目录路径
- save_dir: 权重文件保存目录
- device: 模型运行设备, 默认为 'cuda'
Returns:
- tuple: (model, tokenizer) 模型和分词器对象
"""
# 从预训练目录加载分词器
tokenizer = AutoTokenizer.from_pretrained(tokenizer_path)
# 创建 MiniMind 因果语言模型实例
model = MiniMindForCausalLM(lm_config)
if from_weight!= 'none':
# 如果指定了加载权重, 从检查点文件加载模型参数
if from_weight != 'none':
# 根据是否使用 MoE 构建权重文件名
moe_suffix = '_moe' if lm_config.use_moe else ''
weight_path = f'{save_dir}/{from_weight}_{lm_config.hidden_size}{moe_suffix}.pth'
# 加载权重到 CPU, 然后加载到模型
weights = torch.load(weight_path, map_location=device)
model.load_state_dict(weights, strict=False)
# 打印模型参数量统计信息
get_model_params(model, lm_config)
# 打印可训练参数量 (需要计算梯度)
Logger(f'Trainable Params: {sum(p.numel() for p in model.parameters() if p.requires_grad) / 1e6:.3f}M')
# 将模型移动到指定设备并返回
return model.to(device), tokenizer
class SkipBatchSampler(Sampler):
"""
可跳过指定数量批次的采样器
继承自 PyTorch Sampler, 在训练过程中跳过指定数量的初始批次
常用于从检查点恢复训练时, 跳过已经训练过的批次
"""
def __init__(self, sampler, batch_size, skip_batches=0):
"""
初始化 SkipBatchSampler
Args:
- sampler: 基础采样器, 提供数据索引序列
- batch_size: 每个批次的样本数量
- skip_batches: 要跳过的批次数量, 默认为 0 (不跳过)
"""
self.sampler = sampler
self.batch_size = batch_size
self.skip_batches = skip_batches
def __iter__(self):
"""
生成批次索引迭代器
遍历基础采样器的索引, 组装成批次
跳过指定数量的初始批次
Yields:
- list: 批次索引列表, 每个列表包含 batch_size 个样本索引
"""
batch = []
skipped = 0
# 遍历基础采样器的所有索引
for idx in self.sampler:
batch.append(idx)
# 当批次满时
if len(batch) == self.batch_size:
# 如果还有需要跳过的批次, 跳过当前批次
if skipped < self.skip_batches:
skipped += 1
batch = []
continue
# 产出当前批次
yield batch
batch = []
# 处理最后一个不完整的批次
if len(batch) > 0 and skipped >= self.skip_batches:
yield batch
def __len__(self):
"""
返回有效批次的总数量
Returns:
- int: 跳过 skip_batches 后的可用批次数量
"""
# 计算总批次数, 向上取整
total_batches = (len(self.sampler) + self.batch_size - 1) // self.batch_size
# 返回跳过 skip_batches 后的批次数量, 不能为负数
return max(0, total_batches - self.skip_batches)