diff --git a/model/model_lora.py b/model/model_lora.py index c675f3b..e42e8d4 100644 --- a/model/model_lora.py +++ b/model/model_lora.py @@ -43,8 +43,9 @@ def load_lora(model, path): def save_lora(model, path): + raw_model = getattr(model, '_orig_mod', model) state_dict = {} - for name, module in model.named_modules(): + for name, module in raw_model.named_modules(): if hasattr(module, 'lora'): clean_name = name[7:] if name.startswith("module.") else name lora_state = {f'{clean_name}.lora.{k}': v for k, v in module.lora.state_dict().items()} diff --git a/trainer/train_distillation.py b/trainer/train_distillation.py index bdad533..abae2c1 100644 --- a/trainer/train_distillation.py +++ b/trainer/train_distillation.py @@ -121,10 +121,9 @@ def train_epoch(epoch, loader, iters, teacher_model, lm_config_student, start_st model.eval() moe_suffix = '_moe' if lm_config_student.use_moe else '' ckp = f'{args.save_dir}/{args.save_weight}_{lm_config_student.hidden_size}{moe_suffix}.pth' - if isinstance(model, torch.nn.parallel.DistributedDataParallel): - state_dict = model.module.state_dict() - else: - state_dict = model.state_dict() + raw_model = model.module if isinstance(model, torch.nn.parallel.DistributedDataParallel) else model + raw_model = getattr(raw_model, '_orig_mod', raw_model) + state_dict = raw_model.state_dict() state_dict = {k: v.half().cpu() for k, v in state_dict.items()} torch.save(state_dict, ckp) lm_checkpoint(lm_config_student, weight=args.save_weight, model=model, optimizer=optimizer, scaler=scaler, epoch=epoch, step=step, wandb=wandb, save_dir='../checkpoints') diff --git a/trainer/train_dpo.py b/trainer/train_dpo.py index 4c24414..52311ca 100644 --- a/trainer/train_dpo.py +++ b/trainer/train_dpo.py @@ -108,10 +108,9 @@ def train_epoch(epoch, loader, iters, ref_model, lm_config, start_step=0, wandb= model.eval() moe_suffix = '_moe' if lm_config.use_moe else '' ckp = f'{args.save_dir}/{args.save_weight}_{lm_config.hidden_size}{moe_suffix}.pth' - if isinstance(model, 
torch.nn.parallel.DistributedDataParallel): - state_dict = model.module.state_dict() - else: - state_dict = model.state_dict() + raw_model = model.module if isinstance(model, torch.nn.parallel.DistributedDataParallel) else model + raw_model = getattr(raw_model, '_orig_mod', raw_model) + state_dict = raw_model.state_dict() state_dict = {k: v.half().cpu() for k, v in state_dict.items()} torch.save(state_dict, ckp) lm_checkpoint(lm_config, weight=args.save_weight, model=model, optimizer=optimizer, scaler=scaler, epoch=epoch, step=step, wandb=wandb, save_dir='../checkpoints') diff --git a/trainer/train_full_sft.py b/trainer/train_full_sft.py index a99db31..3908213 100644 --- a/trainer/train_full_sft.py +++ b/trainer/train_full_sft.py @@ -69,10 +69,9 @@ def train_epoch(epoch, loader, iters, start_step=0, wandb=None): model.eval() moe_suffix = '_moe' if lm_config.use_moe else '' ckp = f'{args.save_dir}/{args.save_weight}_{lm_config.hidden_size}{moe_suffix}.pth' - if isinstance(model, torch.nn.parallel.DistributedDataParallel): - state_dict = model.module.state_dict() - else: - state_dict = model.state_dict() + raw_model = model.module if isinstance(model, torch.nn.parallel.DistributedDataParallel) else model + raw_model = getattr(raw_model, '_orig_mod', raw_model) + state_dict = raw_model.state_dict() state_dict = {k: v.half().cpu() for k, v in state_dict.items()} torch.save(state_dict, ckp) lm_checkpoint(lm_config, weight=args.save_weight, model=model, optimizer=optimizer, diff --git a/trainer/train_grpo.py b/trainer/train_grpo.py index 088edcb..593780f 100755 --- a/trainer/train_grpo.py +++ b/trainer/train_grpo.py @@ -180,7 +180,9 @@ def grpo_train_epoch(epoch, loader, iters, ref_model, reward_model, reward_token model.eval() moe_suffix = '_moe' if lm_config.use_moe else '' ckp = f'{args.save_dir}/{args.save_weight}_{lm_config.hidden_size}{moe_suffix}.pth' - state_dict = model.module.state_dict() if isinstance(model, DistributedDataParallel) else model.state_dict() + raw_model = model.module if 
isinstance(model, DistributedDataParallel) else model + raw_model = getattr(raw_model, '_orig_mod', raw_model) + state_dict = raw_model.state_dict() torch.save({k: v.half().cpu() for k, v in state_dict.items()}, ckp) lm_checkpoint(lm_config, weight=args.save_weight, model=model, optimizer=optimizer, epoch=epoch, step=step, wandb=wandb, save_dir='../checkpoints', scheduler=scheduler) diff --git a/trainer/train_ppo.py b/trainer/train_ppo.py index 22bc73d..02e9e0c 100644 --- a/trainer/train_ppo.py +++ b/trainer/train_ppo.py @@ -220,7 +220,9 @@ def ppo_train_epoch(epoch, loader, iters, old_actor_model, ref_model, actor_sche f"Avg Response Len: {avg_len_val:.2f}, Actor LR: {actor_lr:.8f}, Critic LR: {critic_lr:.8f}") if (step + 1) % args.update_old_actor_freq == 0: - state_dict = actor_model.module.state_dict() if isinstance(actor_model, DistributedDataParallel) else actor_model.state_dict() + raw_actor = actor_model.module if isinstance(actor_model, DistributedDataParallel) else actor_model + raw_actor = getattr(raw_actor, '_orig_mod', raw_actor) + state_dict = raw_actor.state_dict() old_actor_model.load_state_dict({k: v.detach().cpu() for k, v in state_dict.items()}) old_actor_model.to(args.device) @@ -228,7 +230,9 @@ def ppo_train_epoch(epoch, loader, iters, old_actor_model, ref_model, actor_sche actor_model.eval() moe_suffix = '_moe' if lm_config.use_moe else '' ckp = f'{args.save_dir}/{args.save_weight}_{lm_config.hidden_size}{moe_suffix}.pth' - actor_state = actor_model.module.state_dict() if isinstance(actor_model, DistributedDataParallel) else actor_model.state_dict() + raw_actor = actor_model.module if isinstance(actor_model, DistributedDataParallel) else actor_model + raw_actor = getattr(raw_actor, '_orig_mod', raw_actor) + actor_state = raw_actor.state_dict() torch.save({k: v.half().cpu() for k, v in actor_state.items()}, ckp) # 使用 lm_checkpoint 保存完整状态(包括 critic) diff --git a/trainer/train_pretrain.py b/trainer/train_pretrain.py index 57234a8..4ebceeb 100644 
--- a/trainer/train_pretrain.py +++ b/trainer/train_pretrain.py @@ -69,10 +69,9 @@ def train_epoch(epoch, loader, iters, start_step=0, wandb=None): model.eval() moe_suffix = '_moe' if lm_config.use_moe else '' ckp = f'{args.save_dir}/{args.save_weight}_{lm_config.hidden_size}{moe_suffix}.pth' - if isinstance(model, torch.nn.parallel.DistributedDataParallel): - state_dict = model.module.state_dict() - else: - state_dict = model.state_dict() + raw_model = model.module if isinstance(model, torch.nn.parallel.DistributedDataParallel) else model + raw_model = getattr(raw_model, '_orig_mod', raw_model) + state_dict = raw_model.state_dict() state_dict = {k: v.half().cpu() for k, v in state_dict.items()} torch.save(state_dict, ckp) lm_checkpoint(lm_config, weight=args.save_weight, model=model, optimizer=optimizer, scaler=scaler, epoch=epoch, step=step, wandb=wandb, save_dir='../checkpoints') diff --git a/trainer/train_reason.py b/trainer/train_reason.py index 728ddac..fbdaf29 100644 --- a/trainer/train_reason.py +++ b/trainer/train_reason.py @@ -82,10 +82,9 @@ def train_epoch(epoch, loader, iters, tokenizer, lm_config, start_step=0, wandb= model.eval() moe_suffix = '_moe' if lm_config.use_moe else '' ckp = f'{args.save_dir}/{args.save_weight}_{lm_config.hidden_size}{moe_suffix}.pth' - if isinstance(model, torch.nn.parallel.DistributedDataParallel): - state_dict = model.module.state_dict() - else: - state_dict = model.state_dict() + raw_model = model.module if isinstance(model, torch.nn.parallel.DistributedDataParallel) else model + raw_model = getattr(raw_model, '_orig_mod', raw_model) + state_dict = raw_model.state_dict() state_dict = {k: v.half().cpu() for k, v in state_dict.items()} torch.save(state_dict, ckp) lm_checkpoint(lm_config, weight=args.save_weight, model=model, optimizer=optimizer, scaler=scaler, epoch=epoch, step=step, wandb=wandb, save_dir='../checkpoints') diff --git a/trainer/train_spo.py b/trainer/train_spo.py index 533f683..d449491 100755 --- a/trainer/train_spo.py +++ 
b/trainer/train_spo.py @@ -228,7 +228,9 @@ def spo_train_epoch(epoch, loader, iters, ref_model, reward_model, reward_tokeni model.eval() moe_suffix = '_moe' if lm_config.use_moe else '' ckp = f'{args.save_dir}/{args.save_weight}_{lm_config.hidden_size}{moe_suffix}.pth' - state_dict = model.module.state_dict() if isinstance(model, DistributedDataParallel) else model.state_dict() + raw_model = model.module if isinstance(model, DistributedDataParallel) else model + raw_model = getattr(raw_model, '_orig_mod', raw_model) + state_dict = raw_model.state_dict() torch.save({k: v.half().cpu() for k, v in state_dict.items()}, ckp) lm_checkpoint(lm_config, weight=args.save_weight, model=model, optimizer=optimizer, epoch=epoch, step=step, wandb=wandb, save_dir='../checkpoints', scheduler=scheduler) diff --git a/trainer/trainer_utils.py b/trainer/trainer_utils.py index 58d0549..3ec1e44 100644 --- a/trainer/trainer_utils.py +++ b/trainer/trainer_utils.py @@ -10,6 +10,7 @@ import math import numpy as np import torch import torch.distributed as dist +from torch.nn.parallel import DistributedDataParallel from torch.utils.data import Sampler from transformers import AutoTokenizer from model.model_minimind import MiniMindForCausalLM @@ -66,8 +67,9 @@ def lm_checkpoint(lm_config, weight='full_sft', model=None, optimizer=None, epoc resume_path = f'{save_dir}/{weight}_{lm_config.hidden_size}{moe_path}_resume.pth' if model is not None: - from torch.nn.parallel import DistributedDataParallel - state_dict = model.module.state_dict() if isinstance(model, DistributedDataParallel) else model.state_dict() + raw_model = model.module if isinstance(model, DistributedDataParallel) else model + raw_model = getattr(raw_model, '_orig_mod', raw_model) + state_dict = raw_model.state_dict() state_dict = {k: v.half().cpu() for k, v in state_dict.items()} ckp_tmp = ckp_path + '.tmp' torch.save(state_dict, ckp_tmp) @@ -91,10 +93,9 @@ def lm_checkpoint(lm_config, weight='full_sft', model=None, 
optimizer=None, epoc for key, value in kwargs.items(): if value is not None: if hasattr(value, 'state_dict'): - if isinstance(value, DistributedDataParallel): - resume_data[key] = value.module.state_dict() - else: - resume_data[key] = value.state_dict() + raw_value = value.module if isinstance(value, DistributedDataParallel) else value + raw_value = getattr(raw_value, '_orig_mod', raw_value) + resume_data[key] = raw_value.state_dict() else: resume_data[key] = value