[fix] compile unpack

This commit is contained in:
jingyaogong 2026-01-14 20:13:32 +08:00
parent 81d24a4f16
commit e119db8478
10 changed files with 36 additions and 31 deletions

View File

@ -43,8 +43,9 @@ def load_lora(model, path):
def save_lora(model, path):
raw_model = getattr(model, '_orig_mod', model)
state_dict = {}
for name, module in model.named_modules():
for name, module in raw_model.named_modules():
if hasattr(module, 'lora'):
clean_name = name[7:] if name.startswith("module.") else name
lora_state = {f'{clean_name}.lora.{k}': v for k, v in module.lora.state_dict().items()}

View File

@ -121,10 +121,9 @@ def train_epoch(epoch, loader, iters, teacher_model, lm_config_student, start_st
model.eval()
moe_suffix = '_moe' if lm_config_student.use_moe else ''
ckp = f'{args.save_dir}/{args.save_weight}_{lm_config_student.hidden_size}{moe_suffix}.pth'
if isinstance(model, torch.nn.parallel.DistributedDataParallel):
state_dict = model.module.state_dict()
else:
state_dict = model.state_dict()
raw_model = model.module if isinstance(model, DistributedDataParallel) else model
raw_model = getattr(raw_model, '_orig_mod', raw_model)
state_dict = raw_model.state_dict()
state_dict = {k: v.half().cpu() for k, v in state_dict.items()}
torch.save(state_dict, ckp)
lm_checkpoint(lm_config_student, weight=args.save_weight, model=model, optimizer=optimizer, scaler=scaler, epoch=epoch, step=step, wandb=wandb, save_dir='../checkpoints')

View File

@ -108,10 +108,9 @@ def train_epoch(epoch, loader, iters, ref_model, lm_config, start_step=0, wandb=
model.eval()
moe_suffix = '_moe' if lm_config.use_moe else ''
ckp = f'{args.save_dir}/{args.save_weight}_{lm_config.hidden_size}{moe_suffix}.pth'
if isinstance(model, torch.nn.parallel.DistributedDataParallel):
state_dict = model.module.state_dict()
else:
state_dict = model.state_dict()
raw_model = model.module if isinstance(model, DistributedDataParallel) else model
raw_model = getattr(raw_model, '_orig_mod', raw_model)
state_dict = raw_model.state_dict()
state_dict = {k: v.half().cpu() for k, v in state_dict.items()}
torch.save(state_dict, ckp)
lm_checkpoint(lm_config, weight=args.save_weight, model=model, optimizer=optimizer, scaler=scaler, epoch=epoch, step=step, wandb=wandb, save_dir='../checkpoints')

View File

@ -69,10 +69,9 @@ def train_epoch(epoch, loader, iters, start_step=0, wandb=None):
model.eval()
moe_suffix = '_moe' if lm_config.use_moe else ''
ckp = f'{args.save_dir}/{args.save_weight}_{lm_config.hidden_size}{moe_suffix}.pth'
if isinstance(model, torch.nn.parallel.DistributedDataParallel):
state_dict = model.module.state_dict()
else:
state_dict = model.state_dict()
raw_model = model.module if isinstance(model, DistributedDataParallel) else model
raw_model = getattr(raw_model, '_orig_mod', raw_model)
state_dict = raw_model.state_dict()
state_dict = {k: v.half().cpu() for k, v in state_dict.items()}
torch.save(state_dict, ckp)
lm_checkpoint(lm_config, weight=args.save_weight, model=model, optimizer=optimizer,

View File

@ -180,7 +180,9 @@ def grpo_train_epoch(epoch, loader, iters, ref_model, reward_model, reward_token
model.eval()
moe_suffix = '_moe' if lm_config.use_moe else ''
ckp = f'{args.save_dir}/{args.save_weight}_{lm_config.hidden_size}{moe_suffix}.pth'
state_dict = model.module.state_dict() if isinstance(model, DistributedDataParallel) else model.state_dict()
raw_model = model.module if isinstance(model, DistributedDataParallel) else model
raw_model = getattr(raw_model, '_orig_mod', raw_model)
state_dict = raw_model.state_dict()
torch.save({k: v.half().cpu() for k, v in state_dict.items()}, ckp)
lm_checkpoint(lm_config, weight=args.save_weight, model=model, optimizer=optimizer,
epoch=epoch, step=step, wandb=wandb, save_dir='../checkpoints', scheduler=scheduler)

View File

@ -220,7 +220,9 @@ def ppo_train_epoch(epoch, loader, iters, old_actor_model, ref_model, actor_sche
f"Avg Response Len: {avg_len_val:.2f}, Actor LR: {actor_lr:.8f}, Critic LR: {critic_lr:.8f}")
if (step + 1) % args.update_old_actor_freq == 0:
state_dict = actor_model.module.state_dict() if isinstance(actor_model, DistributedDataParallel) else actor_model.state_dict()
raw_actor = actor_model.module if isinstance(actor_model, DistributedDataParallel) else actor_model
raw_actor = getattr(raw_actor, '_orig_mod', raw_actor)
state_dict = raw_actor.state_dict()
old_actor_model.load_state_dict({k: v.detach().cpu() for k, v in state_dict.items()})
old_actor_model.to(args.device)
@ -228,7 +230,9 @@ def ppo_train_epoch(epoch, loader, iters, old_actor_model, ref_model, actor_sche
actor_model.eval()
moe_suffix = '_moe' if lm_config.use_moe else ''
ckp = f'{args.save_dir}/{args.save_weight}_{lm_config.hidden_size}{moe_suffix}.pth'
actor_state = actor_model.module.state_dict() if isinstance(actor_model, DistributedDataParallel) else actor_model.state_dict()
raw_actor = actor_model.module if isinstance(actor_model, DistributedDataParallel) else actor_model
raw_actor = getattr(raw_actor, '_orig_mod', raw_actor)
actor_state = raw_actor.state_dict()
torch.save({k: v.half().cpu() for k, v in actor_state.items()}, ckp)
# 使用 lm_checkpoint 保存完整状态(包括 critic

View File

@ -69,10 +69,9 @@ def train_epoch(epoch, loader, iters, start_step=0, wandb=None):
model.eval()
moe_suffix = '_moe' if lm_config.use_moe else ''
ckp = f'{args.save_dir}/{args.save_weight}_{lm_config.hidden_size}{moe_suffix}.pth'
if isinstance(model, torch.nn.parallel.DistributedDataParallel):
state_dict = model.module.state_dict()
else:
state_dict = model.state_dict()
raw_model = model.module if isinstance(model, DistributedDataParallel) else model
raw_model = getattr(raw_model, '_orig_mod', raw_model)
state_dict = raw_model.state_dict()
state_dict = {k: v.half().cpu() for k, v in state_dict.items()}
torch.save(state_dict, ckp)
lm_checkpoint(lm_config, weight=args.save_weight, model=model, optimizer=optimizer, scaler=scaler, epoch=epoch, step=step, wandb=wandb, save_dir='../checkpoints')

View File

@ -82,10 +82,9 @@ def train_epoch(epoch, loader, iters, tokenizer, lm_config, start_step=0, wandb=
model.eval()
moe_suffix = '_moe' if lm_config.use_moe else ''
ckp = f'{args.save_dir}/{args.save_weight}_{lm_config.hidden_size}{moe_suffix}.pth'
if isinstance(model, torch.nn.parallel.DistributedDataParallel):
state_dict = model.module.state_dict()
else:
state_dict = model.state_dict()
raw_model = model.module if isinstance(model, DistributedDataParallel) else model
raw_model = getattr(raw_model, '_orig_mod', raw_model)
state_dict = raw_model.state_dict()
state_dict = {k: v.half().cpu() for k, v in state_dict.items()}
torch.save(state_dict, ckp)
lm_checkpoint(lm_config, weight=args.save_weight, model=model, optimizer=optimizer, scaler=scaler, epoch=epoch, step=step, wandb=wandb, save_dir='../checkpoints')

View File

@ -228,7 +228,9 @@ def spo_train_epoch(epoch, loader, iters, ref_model, reward_model, reward_tokeni
model.eval()
moe_suffix = '_moe' if lm_config.use_moe else ''
ckp = f'{args.save_dir}/{args.save_weight}_{lm_config.hidden_size}{moe_suffix}.pth'
state_dict = model.module.state_dict() if isinstance(model, DistributedDataParallel) else model.state_dict()
raw_model = model.module if isinstance(model, DistributedDataParallel) else model
raw_model = getattr(raw_model, '_orig_mod', raw_model)
state_dict = raw_model.state_dict()
torch.save({k: v.half().cpu() for k, v in state_dict.items()}, ckp)
lm_checkpoint(lm_config, weight=args.save_weight, model=model, optimizer=optimizer,
epoch=epoch, step=step, wandb=wandb, save_dir='../checkpoints', scheduler=scheduler)

View File

@ -10,6 +10,7 @@ import math
import numpy as np
import torch
import torch.distributed as dist
from torch.nn.parallel import DistributedDataParallel
from torch.utils.data import Sampler
from transformers import AutoTokenizer
from model.model_minimind import MiniMindForCausalLM
@ -66,8 +67,9 @@ def lm_checkpoint(lm_config, weight='full_sft', model=None, optimizer=None, epoc
resume_path = f'{save_dir}/{weight}_{lm_config.hidden_size}{moe_path}_resume.pth'
if model is not None:
from torch.nn.parallel import DistributedDataParallel
state_dict = model.module.state_dict() if isinstance(model, DistributedDataParallel) else model.state_dict()
raw_model = model.module if isinstance(model, DistributedDataParallel) else model
raw_model = getattr(raw_model, '_orig_mod', raw_model)
state_dict = raw_model.state_dict()
state_dict = {k: v.half().cpu() for k, v in state_dict.items()}
ckp_tmp = ckp_path + '.tmp'
torch.save(state_dict, ckp_tmp)
@ -91,10 +93,9 @@ def lm_checkpoint(lm_config, weight='full_sft', model=None, optimizer=None, epoc
for key, value in kwargs.items():
if value is not None:
if hasattr(value, 'state_dict'):
if isinstance(value, DistributedDataParallel):
resume_data[key] = value.module.state_dict()
else:
resume_data[key] = value.state_dict()
raw_value = value.module if isinstance(value, DistributedDataParallel) else value
raw_value = getattr(raw_value, '_orig_mod', raw_value)
resume_data[key] = raw_value.state_dict()
else:
resume_data[key] = value