minimind/trainer/trainer_utils.py

"""
训练工具函数集合
"""
import os
import sys
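# Allow running scripts inside trainer/ directly: set __package__ and put the
# repo root on sys.path so `model.model_minimind` resolves without installing
# the project as a package.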
__package__ = "trainer"
sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), '..')))
import random
import math
import numpy as np
import torch
import torch.distributed as dist
from torch.utils.data import Sampler
from transformers import AutoTokenizer
from model.model_minimind import MiniMindForCausalLM
def get_model_params(model, config):
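    """Log total and (for MoE configs) active parameter counts in millions.

    Active = base (non-expert) params + the `num_experts_per_tok` routed
    experts + all shared experts; dense models log the total only.
    """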
total = sum(p.numel() for p in model.parameters()) / 1e6
n_routed = getattr(config, 'n_routed_experts', getattr(config, 'num_experts', 0))
n_active = getattr(config, 'num_experts_per_tok', 0)
n_shared = getattr(config, 'n_shared_experts', 0)
expert = sum(p.numel() for n, p in model.named_parameters() if 'mlp.experts.0.' in n) / 1e6
shared_expert = sum(p.numel() for n, p in model.named_parameters() if 'mlp.shared_experts.0.' in n) / 1e6
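    # Remove every expert copy from the total, then add back only the experts
    # live per token (routed top-k plus the always-on shared experts).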
base = total - (expert * n_routed) - (shared_expert * n_shared)
active = base + (expert * n_active) + (shared_expert * n_shared)
    if active < total:
        Logger(f'Model Params: {total:.2f}M-A{active:.2f}M')
    else:
        Logger(f'Model Params: {total:.2f}M')
def is_main_process():
return not dist.is_initialized() or dist.get_rank() == 0
def Logger(content):
if is_main_process():
print(content)
def get_lr(current_step, total_steps, lr):
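    """Cosine schedule: decays from `lr` at step 0 to 0.1 * lr at `total_steps`.

    E.g. get_lr(0, 1000, 5e-4) -> 5e-4 (peak); get_lr(1000, 1000, 5e-4) -> 5e-5 (floor).
    """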
return lr*(0.1 + 0.45*(1 + math.cos(math.pi * current_step / total_steps)))
def init_distributed_mode():
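    """Initialise the NCCL process group when launched via torchrun and return the local rank (0 when not distributed)."""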
if int(os.environ.get("RANK", -1)) == -1:
        return 0  # non-DDP mode
dist.init_process_group(backend="nccl")
local_rank = int(os.environ["LOCAL_RANK"])
torch.cuda.set_device(local_rank)
return local_rank
def setup_seed(seed: int):
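    """Seed all RNGs and force deterministic cuDNN kernels for reproducibility."""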
random.seed(seed)
np.random.seed(seed)
torch.manual_seed(seed)
torch.cuda.manual_seed(seed)
torch.cuda.manual_seed_all(seed)
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False
def lm_checkpoint(lm_config, weight='full_sft', model=None, optimizer=None, epoch=0, step=0, wandb=None, save_dir='../checkpoints', **kwargs):
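    """Save or load a training checkpoint, keyed by weight name, hidden size, and MoE flag.

    With `model` given: atomically write fp16 weights plus a `_resume` file
    holding optimizer/epoch/step/wandb metadata and any extra `kwargs`
    (objects exposing a state_dict are serialized as such). With `model=None`:
    load and return the resume data if present, else None.
    """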
os.makedirs(save_dir, exist_ok=True)
moe_path = '_moe' if lm_config.use_moe else ''
ckp_path = f'{save_dir}/{weight}_{lm_config.hidden_size}{moe_path}.pth'
resume_path = f'{save_dir}/{weight}_{lm_config.hidden_size}{moe_path}_resume.pth'
if model is not None:
from torch.nn.parallel import DistributedDataParallel
state_dict = model.module.state_dict() if isinstance(model, DistributedDataParallel) else model.state_dict()
state_dict = {k: v.half().cpu() for k, v in state_dict.items()}
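        # Write to a temp file and os.replace() it into place so an interrupted
        # save never leaves a truncated checkpoint behind.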
ckp_tmp = ckp_path + '.tmp'
torch.save(state_dict, ckp_tmp)
os.replace(ckp_tmp, ckp_path)
wandb_id = None
if wandb:
if hasattr(wandb, 'get_run'):
run = wandb.get_run()
wandb_id = getattr(run, 'id', None) if run else None
else:
wandb_id = getattr(wandb, 'id', None)
resume_data = {
'model': state_dict,
            'optimizer': optimizer.state_dict() if optimizer is not None else None,
'epoch': epoch,
'step': step,
'world_size': dist.get_world_size() if dist.is_initialized() else 1,
'wandb_id': wandb_id
}
for key, value in kwargs.items():
if value is not None:
if hasattr(value, 'state_dict'):
if isinstance(value, DistributedDataParallel):
resume_data[key] = value.module.state_dict()
else:
resume_data[key] = value.state_dict()
else:
resume_data[key] = value
resume_tmp = resume_path + '.tmp'
torch.save(resume_data, resume_tmp)
os.replace(resume_tmp, resume_path)
del state_dict, resume_data
torch.cuda.empty_cache()
    else:  # load mode: no model passed, so restore checkpoint state instead
if os.path.exists(resume_path):
ckp_data = torch.load(resume_path, map_location='cpu')
saved_ws = ckp_data.get('world_size', 1)
current_ws = dist.get_world_size() if dist.is_initialized() else 1
if saved_ws != current_ws:
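                # `step` counts per-rank optimizer steps; the global batch count
                # is step * world_size, so rescale to resume at the same data offset.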
ckp_data['step'] = ckp_data['step'] * saved_ws // current_ws
                Logger(f'World size changed ({saved_ws} -> {current_ws}); step automatically rescaled to {ckp_data["step"]}')
return ckp_data
return None
def init_model(lm_config, from_weight='pretrain', tokenizer_path='../model', save_dir='../out', device='cuda'):
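    """Build the tokenizer and MiniMind model, optionally loading saved weights."""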
tokenizer = AutoTokenizer.from_pretrained(tokenizer_path)
model = MiniMindForCausalLM(lm_config)
    if from_weight != 'none':
moe_suffix = '_moe' if lm_config.use_moe else ''
weight_path = f'{save_dir}/{from_weight}_{lm_config.hidden_size}{moe_suffix}.pth'
weights = torch.load(weight_path, map_location=device)
model.load_state_dict(weights, strict=False)
get_model_params(model, lm_config)
Logger(f'Trainable Params: {sum(p.numel() for p in model.parameters() if p.requires_grad) / 1e6:.3f}M')
return model.to(device), tokenizer
class SkipBatchSampler(Sampler):
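    """Batch sampler that silently drops the first `skip_batches` batches.

    Used when resuming mid-epoch so batches consumed before the checkpoint
    are not replayed.
    """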
def __init__(self, sampler, batch_size, skip_batches=0):
self.sampler = sampler
self.batch_size = batch_size
self.skip_batches = skip_batches
def __iter__(self):
batch = []
skipped = 0
for idx in self.sampler:
batch.append(idx)
if len(batch) == self.batch_size:
if skipped < self.skip_batches:
skipped += 1
batch = []
continue
yield batch
batch = []
if len(batch) > 0 and skipped >= self.skip_batches:
yield batch
def __len__(self):
total_batches = (len(self.sampler) + self.batch_size - 1) // self.batch_size
return max(0, total_batches - self.skip_batches)
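
# Usage sketch (hypothetical resume-aware loop; names and values are illustrative):
#   setup_seed(42)
#   local_rank = init_distributed_mode()
#   model, tokenizer = init_model(lm_config, from_weight='pretrain')
#   ckp = lm_checkpoint(lm_config, weight='full_sft')              # load mode
#   start_epoch, start_step = (ckp['epoch'], ckp['step']) if ckp else (0, 0)
#   ...train...
#   lm_checkpoint(lm_config, weight='full_sft', model=model,       # save mode
#                 optimizer=optimizer, epoch=epoch, step=step)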