"""
|
||
训练工具函数集合
|
||
"""
|
||
import os
import sys

__package__ = "trainer"
sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), '..')))

import random
import math
import numpy as np
import torch
import torch.distributed as dist
from torch.utils.data import Sampler
from transformers import AutoTokenizer

from model.model_minimind import MiniMindForCausalLM


def get_model_params(model, config):
    """Log total (and, for MoE models, active) parameter counts in millions."""
    total = sum(p.numel() for p in model.parameters()) / 1e6
    if getattr(config, 'use_moe', False):
        n_routed = getattr(config, 'n_routed_experts', getattr(config, 'num_experts', 0))
        n_active = getattr(config, 'num_experts_per_tok', 0)
        n_shared = getattr(config, 'n_shared_experts', 0)
        # Parameters of a single routed / shared expert, measured on layer 0
        expert = sum(p.numel() for n, p in model.named_parameters() if 'mlp.experts.0.' in n) / 1e6
        shared_expert = sum(p.numel() for n, p in model.named_parameters() if 'mlp.shared_experts.0.' in n) / 1e6
        # Non-expert ("base") parameters, then the parameters active per token
        base = total - (expert * n_routed) - (shared_expert * n_shared)
        active = base + (expert * n_active) + (shared_expert * n_shared)
        Logger(f'Model Params: {total:.2f}M-A{active:.2f}M')
    else:
        Logger(f'Model Params: {total:.2f}M')


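# Illustrative sketch of the MoE accounting above (never called; the numbers are
# made up, not MiniMind's real sizes): one routed expert is counted once and then
# scaled by how many experts exist vs. how many are active per token.
def _example_moe_param_count():
    total = 100.0          # hypothetical total parameters, in millions
    expert = 5.0           # parameters of one routed expert (mlp.experts.0.*)
    shared_expert = 5.0    # parameters of one shared expert
    n_routed, n_active, n_shared = 8, 2, 1
    base = total - expert * n_routed - shared_expert * n_shared    # 55.0M non-expert params
    active = base + expert * n_active + shared_expert * n_shared   # 70.0M active per token
    return base, active

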
def is_main_process():
    return not dist.is_initialized() or dist.get_rank() == 0


def Logger(content):
    if is_main_process():
        print(content)


def get_lr(current_step, total_steps, lr):
    # Cosine decay from lr (at step 0) down to 0.1 * lr (at the final step)
    return lr * (0.1 + 0.45 * (1 + math.cos(math.pi * current_step / total_steps)))


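# Sanity check of the schedule (values approximate, up to floating-point error):
#   get_lr(0, 1000, 1e-3)    -> 1.0e-3   # 0.1 + 0.45 * (1 + cos(0))
#   get_lr(500, 1000, 1e-3)  -> 5.5e-4   # 0.1 + 0.45 * (1 + cos(pi/2))
#   get_lr(1000, 1000, 1e-3) -> 1.0e-4   # 0.1 + 0.45 * (1 + cos(pi))

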
def init_distributed_mode():
    if int(os.environ.get("RANK", -1)) == -1:
        return 0  # non-DDP mode

    dist.init_process_group(backend="nccl")
    local_rank = int(os.environ["LOCAL_RANK"])
    torch.cuda.set_device(local_rank)
    return local_rank


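# Typical launch (sketch; assumes torchrun sets RANK / LOCAL_RANK / WORLD_SIZE):
#   torchrun --nproc_per_node=4 <your_training_script>.py
# and inside the script:
#   local_rank = init_distributed_mode()
#   device = f'cuda:{local_rank}'
# If RANK is not set, the function returns 0 and training stays single-process.

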
def setup_seed(seed: int):
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False


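# Note (convention, not enforced here): DDP callers often offset the seed per rank,
# e.g. setup_seed(base_seed + rank), so shuffling/dropout differ across processes
# while each individual run remains reproducible.

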
def lm_checkpoint(lm_config, weight='full_sft', model=None, optimizer=None, epoch=0, step=0, wandb=None, save_dir='../checkpoints', **kwargs):
    os.makedirs(save_dir, exist_ok=True)
    moe_path = '_moe' if lm_config.use_moe else ''
    ckp_path = f'{save_dir}/{weight}_{lm_config.hidden_size}{moe_path}.pth'
    resume_path = f'{save_dir}/{weight}_{lm_config.hidden_size}{moe_path}_resume.pth'

    if model is not None:  # save mode
        from torch.nn.parallel import DistributedDataParallel
        state_dict = model.module.state_dict() if isinstance(model, DistributedDataParallel) else model.state_dict()
        state_dict = {k: v.half().cpu() for k, v in state_dict.items()}
        # Write atomically: dump to a .tmp file first, then rename over the target
        ckp_tmp = ckp_path + '.tmp'
        torch.save(state_dict, ckp_tmp)
        os.replace(ckp_tmp, ckp_path)
        wandb_id = None
        if wandb:
            if hasattr(wandb, 'get_run'):
                run = wandb.get_run()
                wandb_id = getattr(run, 'id', None) if run else None
            else:
                wandb_id = getattr(wandb, 'id', None)

        resume_data = {
            'model': state_dict,
            'optimizer': optimizer.state_dict(),
            'epoch': epoch,
            'step': step,
            'world_size': dist.get_world_size() if dist.is_initialized() else 1,
            'wandb_id': wandb_id
        }
        for key, value in kwargs.items():
            if value is not None:
                if hasattr(value, 'state_dict'):
                    if isinstance(value, DistributedDataParallel):
                        resume_data[key] = value.module.state_dict()
                    else:
                        resume_data[key] = value.state_dict()
                else:
                    resume_data[key] = value

        resume_tmp = resume_path + '.tmp'
        torch.save(resume_data, resume_tmp)
        os.replace(resume_tmp, resume_path)
        del state_dict, resume_data
        torch.cuda.empty_cache()
    else:  # load mode
        if os.path.exists(resume_path):
            ckp_data = torch.load(resume_path, map_location='cpu')
            saved_ws = ckp_data.get('world_size', 1)
            current_ws = dist.get_world_size() if dist.is_initialized() else 1
            if saved_ws != current_ws:
                ckp_data['step'] = ckp_data['step'] * saved_ws // current_ws
                Logger(f'World size changed ({saved_ws}→{current_ws}); step automatically rescaled to {ckp_data["step"]}')
            return ckp_data
        return None


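# Usage sketch (variable names are illustrative, not defined in this module):
#   # Save: writes <weight>_<hidden_size>[_moe].pth plus a *_resume.pth holding
#   # optimizer/epoch/step state; both writes go through a .tmp file + os.replace.
#   lm_checkpoint(lm_config, weight='full_sft', model=model, optimizer=optimizer,
#                 epoch=epoch, step=step, wandb=wandb, scaler=scaler)
#   # Load: call without a model; returns the resume dict, or None if no checkpoint.
#   ckpt = lm_checkpoint(lm_config, weight='full_sft')
#   if ckpt is not None:
#       model.load_state_dict(ckpt['model'], strict=False)
#       optimizer.load_state_dict(ckpt['optimizer'])
#       start_epoch, start_step = ckpt['epoch'], ckpt['step']

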
def init_model(lm_config, from_weight='pretrain', tokenizer_path='../model', save_dir='../out', device='cuda'):
    tokenizer = AutoTokenizer.from_pretrained(tokenizer_path)
    model = MiniMindForCausalLM(lm_config)

    if from_weight != 'none':
        moe_suffix = '_moe' if lm_config.use_moe else ''
        weight_path = f'{save_dir}/{from_weight}_{lm_config.hidden_size}{moe_suffix}.pth'
        weights = torch.load(weight_path, map_location=device)
        model.load_state_dict(weights, strict=False)

    get_model_params(model, lm_config)
    Logger(f'Trainable Params: {sum(p.numel() for p in model.parameters() if p.requires_grad) / 1e6:.3f}M')
    return model.to(device), tokenizer


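# Usage sketch (assumes MiniMindConfig is the config class exported alongside
# MiniMindForCausalLM; the paths below are the defaults, adjust to your layout):
#   from model.model_minimind import MiniMindConfig
#   lm_config = MiniMindConfig(hidden_size=512, num_hidden_layers=8, use_moe=False)
#   model, tokenizer = init_model(lm_config, from_weight='pretrain',
#                                 tokenizer_path='../model', save_dir='../out')
#   # from_weight='none' skips weight loading and returns a randomly initialized model.

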
class SkipBatchSampler(Sampler):
    """Batch sampler that silently skips the first `skip_batches` batches (used to resume mid-epoch)."""

    def __init__(self, sampler, batch_size, skip_batches=0):
        self.sampler = sampler
        self.batch_size = batch_size
        self.skip_batches = skip_batches

    def __iter__(self):
        batch = []
        skipped = 0
        for idx in self.sampler:
            batch.append(idx)
            if len(batch) == self.batch_size:
                if skipped < self.skip_batches:
                    skipped += 1
                    batch = []
                    continue
                yield batch
                batch = []
        # Yield the trailing partial batch only if the skip phase is already over
        if len(batch) > 0 and skipped >= self.skip_batches:
            yield batch

    def __len__(self):
        total_batches = (len(self.sampler) + self.batch_size - 1) // self.batch_size
        return max(0, total_batches - self.skip_batches)
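

# Usage sketch for resuming mid-epoch (names are illustrative):
#   sampler = torch.utils.data.RandomSampler(train_ds)
#   batch_sampler = SkipBatchSampler(sampler, batch_size=32, skip_batches=start_step)
#   loader = torch.utils.data.DataLoader(train_ds, batch_sampler=batch_sampler,
#                                        num_workers=2, pin_memory=True)
# The first `skip_batches` batches are consumed without being yielded, so iteration
# picks up at the step recorded in the resume checkpoint.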