mirror of https://github.com/jingyaogong/minimind.git
[fix] compile unpack
This commit is contained in:
parent 81d24a4f16
commit e119db8478
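All hunks in this commit apply the same two-step unwrap before calling state_dict() or named_modules(): strip the DistributedDataParallel wrapper via .module, then strip the torch.compile wrapper via its _orig_mod attribute. A minimal sketch of the pattern as a standalone helper (the helper name is illustrative; the diff inlines the two steps at each call site):

# Sketch of the unwrap pattern this commit inlines at every save site.
# `unwrap_model` is a hypothetical helper name, not part of the repo.
import torch
from torch.nn.parallel import DistributedDataParallel

def unwrap_model(model: torch.nn.Module) -> torch.nn.Module:
    # DDP stores the wrapped network in .module.
    if isinstance(model, DistributedDataParallel):
        model = model.module
    # torch.compile wraps the original module in an OptimizedModule that
    # keeps it in ._orig_mod; getattr is a safe no-op for plain modules.
    return getattr(model, '_orig_mod', model)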
@@ -43,8 +43,9 @@ def load_lora(model, path):
 def save_lora(model, path):
+    raw_model = getattr(model, '_orig_mod', model)
     state_dict = {}
-    for name, module in model.named_modules():
+    for name, module in raw_model.named_modules():
         if hasattr(module, 'lora'):
             clean_name = name[7:] if name.startswith("module.") else name
             lora_state = {f'{clean_name}.lora.{k}': v for k, v in module.lora.state_dict().items()}
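The save_lora change matters because torch.compile wraps the model in an OptimizedModule, so names from named_modules() pick up an '_orig_mod.' prefix and the LoRA keys written here would no longer match what load_lora expects. A toy sketch of the prefix, with a plain Sequential standing in for the real model:

# Module names gain an '_orig_mod.' prefix under torch.compile, which is
# what broke save_lora's keys (toy module; names are illustrative).
import torch

net = torch.nn.Sequential(torch.nn.Linear(4, 4))
compiled = torch.compile(net)

print([n for n, _ in net.named_modules()])       # ['', '0']
print([n for n, _ in compiled.named_modules()])  # ['', '_orig_mod', '_orig_mod.0']

# Iterating over the unwrapped module restores the clean names:
raw = getattr(compiled, '_orig_mod', compiled)
assert [n for n, _ in raw.named_modules()] == [n for n, _ in net.named_modules()]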
@@ -121,10 +121,9 @@ def train_epoch(epoch, loader, iters, teacher_model, lm_config_student, start_st
             model.eval()
             moe_suffix = '_moe' if lm_config_student.use_moe else ''
             ckp = f'{args.save_dir}/{args.save_weight}_{lm_config_student.hidden_size}{moe_suffix}.pth'
-            if isinstance(model, torch.nn.parallel.DistributedDataParallel):
-                state_dict = model.module.state_dict()
-            else:
-                state_dict = model.state_dict()
+            raw_model = model.module if isinstance(model, DistributedDataParallel) else model
+            raw_model = getattr(raw_model, '_orig_mod', raw_model)
+            state_dict = raw_model.state_dict()
             state_dict = {k: v.half().cpu() for k, v in state_dict.items()}
             torch.save(state_dict, ckp)
             lm_checkpoint(lm_config_student, weight=args.save_weight, model=model, optimizer=optimizer, scaler=scaler, epoch=epoch, step=step, wandb=wandb, save_dir='../checkpoints')
@@ -108,10 +108,9 @@ def train_epoch(epoch, loader, iters, ref_model, lm_config, start_step=0, wandb=
             model.eval()
             moe_suffix = '_moe' if lm_config.use_moe else ''
             ckp = f'{args.save_dir}/{args.save_weight}_{lm_config.hidden_size}{moe_suffix}.pth'
-            if isinstance(model, torch.nn.parallel.DistributedDataParallel):
-                state_dict = model.module.state_dict()
-            else:
-                state_dict = model.state_dict()
+            raw_model = model.module if isinstance(model, DistributedDataParallel) else model
+            raw_model = getattr(raw_model, '_orig_mod', raw_model)
+            state_dict = raw_model.state_dict()
             state_dict = {k: v.half().cpu() for k, v in state_dict.items()}
             torch.save(state_dict, ckp)
             lm_checkpoint(lm_config, weight=args.save_weight, model=model, optimizer=optimizer, scaler=scaler, epoch=epoch, step=step, wandb=wandb, save_dir='../checkpoints')
@@ -69,10 +69,9 @@ def train_epoch(epoch, loader, iters, start_step=0, wandb=None):
             model.eval()
             moe_suffix = '_moe' if lm_config.use_moe else ''
             ckp = f'{args.save_dir}/{args.save_weight}_{lm_config.hidden_size}{moe_suffix}.pth'
-            if isinstance(model, torch.nn.parallel.DistributedDataParallel):
-                state_dict = model.module.state_dict()
-            else:
-                state_dict = model.state_dict()
+            raw_model = model.module if isinstance(model, DistributedDataParallel) else model
+            raw_model = getattr(raw_model, '_orig_mod', raw_model)
+            state_dict = raw_model.state_dict()
             state_dict = {k: v.half().cpu() for k, v in state_dict.items()}
             torch.save(state_dict, ckp)
             lm_checkpoint(lm_config, weight=args.save_weight, model=model, optimizer=optimizer,
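This three-line replacement recurs across the trainers below; the bug it fixes is the key prefix a compiled model's state_dict() carries, which makes the saved .pth unloadable into a plain model. A toy illustration (the Linear stands in for the real model):

# Without the fix, checkpoint keys from a compiled model carry a prefix.
import torch

net = torch.nn.Linear(4, 4)
compiled = torch.compile(net)
print(list(net.state_dict()))        # ['weight', 'bias']
print(list(compiled.state_dict()))   # ['_orig_mod.weight', '_orig_mod.bias']

# The fix: take state_dict from the unwrapped module, so keys stay clean.
raw = getattr(compiled, '_orig_mod', compiled)
assert list(raw.state_dict()) == list(net.state_dict())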
@@ -180,7 +180,9 @@ def grpo_train_epoch(epoch, loader, iters, ref_model, reward_model, reward_token
             model.eval()
             moe_suffix = '_moe' if lm_config.use_moe else ''
             ckp = f'{args.save_dir}/{args.save_weight}_{lm_config.hidden_size}{moe_suffix}.pth'
-            state_dict = model.module.state_dict() if isinstance(model, DistributedDataParallel) else model.state_dict()
+            raw_model = model.module if isinstance(model, DistributedDataParallel) else model
+            raw_model = getattr(raw_model, '_orig_mod', raw_model)
+            state_dict = raw_model.state_dict()
             torch.save({k: v.half().cpu() for k, v in state_dict.items()}, ckp)
             lm_checkpoint(lm_config, weight=args.save_weight, model=model, optimizer=optimizer,
                           epoch=epoch, step=step, wandb=wandb, save_dir='../checkpoints', scheduler=scheduler)
@@ -220,7 +220,9 @@ def ppo_train_epoch(epoch, loader, iters, old_actor_model, ref_model, actor_sche
                         f"Avg Response Len: {avg_len_val:.2f}, Actor LR: {actor_lr:.8f}, Critic LR: {critic_lr:.8f}")

         if (step + 1) % args.update_old_actor_freq == 0:
-            state_dict = actor_model.module.state_dict() if isinstance(actor_model, DistributedDataParallel) else actor_model.state_dict()
+            raw_actor = actor_model.module if isinstance(actor_model, DistributedDataParallel) else actor_model
+            raw_actor = getattr(raw_actor, '_orig_mod', raw_actor)
+            state_dict = raw_actor.state_dict()
             old_actor_model.load_state_dict({k: v.detach().cpu() for k, v in state_dict.items()})
             old_actor_model.to(args.device)
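The same unwrap feeds the periodic old-actor refresh just above: weights are detached to CPU and copied into the frozen snapshot. A toy sketch of that sync, with Linear modules standing in for the actor models and the device chosen locally instead of via args.device:

# Refreshing a frozen snapshot from the live policy (toy sketch).
import torch

actor = torch.nn.Linear(8, 8)
old_actor = torch.nn.Linear(8, 8)

raw_actor = getattr(actor, '_orig_mod', actor)   # no-op here; unwraps if compiled
state = {k: v.detach().cpu() for k, v in raw_actor.state_dict().items()}
old_actor.load_state_dict(state)                 # copy weights, no autograd history
old_actor.to('cuda' if torch.cuda.is_available() else 'cpu')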
@@ -228,7 +230,9 @@ def ppo_train_epoch(epoch, loader, iters, old_actor_model, ref_model, actor_sche
             actor_model.eval()
             moe_suffix = '_moe' if lm_config.use_moe else ''
             ckp = f'{args.save_dir}/{args.save_weight}_{lm_config.hidden_size}{moe_suffix}.pth'
-            actor_state = actor_model.module.state_dict() if isinstance(actor_model, DistributedDataParallel) else actor_model.state_dict()
+            raw_actor = actor_model.module if isinstance(actor_model, DistributedDataParallel) else actor_model
+            raw_actor = getattr(raw_actor, '_orig_mod', raw_actor)
+            actor_state = raw_actor.state_dict()
             torch.save({k: v.half().cpu() for k, v in actor_state.items()}, ckp)

             # Save the full state via lm_checkpoint (including the critic)
@@ -69,10 +69,9 @@ def train_epoch(epoch, loader, iters, start_step=0, wandb=None):
             model.eval()
             moe_suffix = '_moe' if lm_config.use_moe else ''
             ckp = f'{args.save_dir}/{args.save_weight}_{lm_config.hidden_size}{moe_suffix}.pth'
-            if isinstance(model, torch.nn.parallel.DistributedDataParallel):
-                state_dict = model.module.state_dict()
-            else:
-                state_dict = model.state_dict()
+            raw_model = model.module if isinstance(model, DistributedDataParallel) else model
+            raw_model = getattr(raw_model, '_orig_mod', raw_model)
+            state_dict = raw_model.state_dict()
             state_dict = {k: v.half().cpu() for k, v in state_dict.items()}
             torch.save(state_dict, ckp)
             lm_checkpoint(lm_config, weight=args.save_weight, model=model, optimizer=optimizer, scaler=scaler, epoch=epoch, step=step, wandb=wandb, save_dir='../checkpoints')
@@ -82,10 +82,9 @@ def train_epoch(epoch, loader, iters, tokenizer, lm_config, start_step=0, wandb=
             model.eval()
             moe_suffix = '_moe' if lm_config.use_moe else ''
             ckp = f'{args.save_dir}/{args.save_weight}_{lm_config.hidden_size}{moe_suffix}.pth'
-            if isinstance(model, torch.nn.parallel.DistributedDataParallel):
-                state_dict = model.module.state_dict()
-            else:
-                state_dict = model.state_dict()
+            raw_model = model.module if isinstance(model, DistributedDataParallel) else model
+            raw_model = getattr(raw_model, '_orig_mod', raw_model)
+            state_dict = raw_model.state_dict()
             state_dict = {k: v.half().cpu() for k, v in state_dict.items()}
             torch.save(state_dict, ckp)
             lm_checkpoint(lm_config, weight=args.save_weight, model=model, optimizer=optimizer, scaler=scaler, epoch=epoch, step=step, wandb=wandb, save_dir='../checkpoints')
@@ -228,7 +228,9 @@ def spo_train_epoch(epoch, loader, iters, ref_model, reward_model, reward_tokeni
             model.eval()
             moe_suffix = '_moe' if lm_config.use_moe else ''
             ckp = f'{args.save_dir}/{args.save_weight}_{lm_config.hidden_size}{moe_suffix}.pth'
-            state_dict = model.module.state_dict() if isinstance(model, DistributedDataParallel) else model.state_dict()
+            raw_model = model.module if isinstance(model, DistributedDataParallel) else model
+            raw_model = getattr(raw_model, '_orig_mod', raw_model)
+            state_dict = raw_model.state_dict()
             torch.save({k: v.half().cpu() for k, v in state_dict.items()}, ckp)
             lm_checkpoint(lm_config, weight=args.save_weight, model=model, optimizer=optimizer,
                           epoch=epoch, step=step, wandb=wandb, save_dir='../checkpoints', scheduler=scheduler)
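Every save path above casts to fp16 on CPU before torch.save, which halves checkpoint size; loading casts back up to the training dtype. A sketch of the round trip (toy module and file name; the repo's loaders may handle the cast differently):

# fp16-on-CPU checkpointing as used above, plus the matching load-side cast.
import torch

net = torch.nn.Linear(16, 16)
state = {k: v.half().cpu() for k, v in net.state_dict().items()}
torch.save(state, 'demo_fp16.pth')

restored = torch.nn.Linear(16, 16)
loaded = torch.load('demo_fp16.pth', map_location='cpu')
restored.load_state_dict({k: v.float() for k, v in loaded.items()})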
@@ -10,6 +10,7 @@ import math
 import numpy as np
 import torch
 import torch.distributed as dist
+from torch.nn.parallel import DistributedDataParallel
 from torch.utils.data import Sampler
 from transformers import AutoTokenizer
 from model.model_minimind import MiniMindForCausalLM
@@ -66,8 +67,9 @@ def lm_checkpoint(lm_config, weight='full_sft', model=None, optimizer=None, epoc
     resume_path = f'{save_dir}/{weight}_{lm_config.hidden_size}{moe_path}_resume.pth'

     if model is not None:
-        from torch.nn.parallel import DistributedDataParallel
-        state_dict = model.module.state_dict() if isinstance(model, DistributedDataParallel) else model.state_dict()
+        raw_model = model.module if isinstance(model, DistributedDataParallel) else model
+        raw_model = getattr(raw_model, '_orig_mod', raw_model)
+        state_dict = raw_model.state_dict()
         state_dict = {k: v.half().cpu() for k, v in state_dict.items()}
         ckp_tmp = ckp_path + '.tmp'
         torch.save(state_dict, ckp_tmp)
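The context lines in this hunk show lm_checkpoint writing to a '.tmp' path before the final file. A minimal sketch of that atomic-save idea; the os.replace rename step is an assumption, since the finishing code is outside this diff:

# Atomic checkpoint write, as suggested by the '.tmp' context lines above.
import os
import torch

def atomic_save(state_dict, ckp_path):
    ckp_tmp = ckp_path + '.tmp'
    torch.save(state_dict, ckp_tmp)  # write to a scratch file first
    os.replace(ckp_tmp, ckp_path)    # swap in atomically; a crash never leaves a torn file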
@@ -91,10 +93,9 @@ def lm_checkpoint(lm_config, weight='full_sft', model=None, optimizer=None, epoc
     for key, value in kwargs.items():
         if value is not None:
             if hasattr(value, 'state_dict'):
-                if isinstance(value, DistributedDataParallel):
-                    resume_data[key] = value.module.state_dict()
-                else:
-                    resume_data[key] = value.state_dict()
+                raw_value = value.module if isinstance(value, DistributedDataParallel) else value
+                raw_value = getattr(raw_value, '_orig_mod', raw_value)
+                resume_data[key] = raw_value.state_dict()
             else:
                 resume_data[key] = value
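Net effect of the commit: a checkpoint saved from a DDP-wrapped and/or compiled model now loads into a plain model without key surgery. An end-to-end sketch (MiniMindConfig is assumed to be the repo's config class; the path is illustrative):

# Round trip: compile, save with the new unwrap, load into a fresh model.
import torch
from model.model_minimind import MiniMindForCausalLM, MiniMindConfig  # MiniMindConfig assumed

model = torch.compile(MiniMindForCausalLM(MiniMindConfig()))

raw_model = getattr(model, '_orig_mod', model)  # the new unwrap step
state_dict = {k: v.half().cpu() for k, v in raw_model.state_dict().items()}
torch.save(state_dict, 'demo_ckpt.pth')

fresh = MiniMindForCausalLM(MiniMindConfig())
fresh.load_state_dict({k: v.float() for k, v in torch.load('demo_ckpt.pth').items()})
print('checkpoint keys match a plain model')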