mirror of
https://github.com/jingyaogong/minimind.git
synced 2026-01-14 04:07:17 +08:00
[fix] cuda memory #559
This commit is contained in:
parent
151fdf7e76
commit
5e1447b913
@ -85,10 +85,11 @@ def train_epoch(epoch, loader, iters, tokenizer, lm_config, start_step=0, wandb=
|
||||
state_dict = model.module.state_dict()
|
||||
else:
|
||||
state_dict = model.state_dict()
|
||||
state_dict = {k: v.half() for k, v in state_dict.items()} # 半精度保存
|
||||
state_dict = {k: v.half().cpu() for k, v in state_dict.items()}
|
||||
torch.save(state_dict, ckp)
|
||||
lm_checkpoint(lm_config, weight=args.save_weight, model=model, optimizer=optimizer, scaler=scaler, epoch=epoch, step=step, wandb=wandb, save_dir='../checkpoints')
|
||||
model.train()
|
||||
del state_dict
|
||||
|
||||
del X, Y, loss_mask, res, loss
|
||||
|
||||
|
||||
@ -123,10 +123,11 @@ def train_epoch(epoch, loader, iters, teacher_model, lm_config_student, start_st
|
||||
state_dict = model.module.state_dict()
|
||||
else:
|
||||
state_dict = model.state_dict()
|
||||
state_dict = {k: v.half() for k, v in state_dict.items()} # 半精度保存
|
||||
state_dict = {k: v.half().cpu() for k, v in state_dict.items()}
|
||||
torch.save(state_dict, ckp)
|
||||
lm_checkpoint(lm_config_student, weight=args.save_weight, model=model, optimizer=optimizer, scaler=scaler, epoch=epoch, step=step, wandb=wandb, save_dir='../checkpoints')
|
||||
model.train()
|
||||
del state_dict
|
||||
|
||||
del X, Y, loss_mask, res, student_logits, teacher_logits, ce_loss, distill_loss, loss
|
||||
|
||||
|
||||
@ -110,10 +110,11 @@ def train_epoch(epoch, loader, iters, ref_model, lm_config, start_step=0, wandb=
|
||||
state_dict = model.module.state_dict()
|
||||
else:
|
||||
state_dict = model.state_dict()
|
||||
state_dict = {k: v.half() for k, v in state_dict.items()} # 半精度保存
|
||||
state_dict = {k: v.half().cpu() for k, v in state_dict.items()}
|
||||
torch.save(state_dict, ckp)
|
||||
lm_checkpoint(lm_config, weight=args.save_weight, model=model, optimizer=optimizer, scaler=scaler, epoch=epoch, step=step, wandb=wandb, save_dir='../checkpoints')
|
||||
model.train()
|
||||
del state_dict
|
||||
|
||||
del x_chosen, x_rejected, y_chosen, y_rejected, mask_chosen, mask_rejected, x, y, mask
|
||||
del ref_outputs, ref_logits, ref_log_probs, outputs, logits, policy_log_probs, loss
|
||||
|
||||
@ -72,11 +72,12 @@ def train_epoch(epoch, loader, iters, start_step=0, wandb=None):
|
||||
state_dict = model.module.state_dict()
|
||||
else:
|
||||
state_dict = model.state_dict()
|
||||
state_dict = {k: v.half() for k, v in state_dict.items()} # 半精度保存
|
||||
state_dict = {k: v.half().cpu() for k, v in state_dict.items()}
|
||||
torch.save(state_dict, ckp)
|
||||
lm_checkpoint(lm_config, weight=args.save_weight, model=model, optimizer=optimizer,
|
||||
epoch=epoch, step=step, wandb=wandb, save_dir='../checkpoints', scaler=scaler)
|
||||
model.train()
|
||||
del state_dict
|
||||
|
||||
del X, Y, loss_mask, res, loss
|
||||
|
||||
|
||||
@ -175,10 +175,11 @@ def grpo_train_epoch(epoch, loader, iters, ref_model, reward_model, reward_token
|
||||
moe_suffix = '_moe' if lm_config.use_moe else ''
|
||||
ckp = f'{args.save_dir}/{args.save_weight}_{lm_config.hidden_size}{moe_suffix}.pth'
|
||||
state_dict = model.module.state_dict() if isinstance(model, DistributedDataParallel) else model.state_dict()
|
||||
torch.save({k: v.half() for k, v in state_dict.items()}, ckp)
|
||||
torch.save({k: v.half().cpu() for k, v in state_dict.items()}, ckp)
|
||||
lm_checkpoint(lm_config, weight=args.save_weight, model=model, optimizer=optimizer,
|
||||
epoch=epoch, step=step, wandb=wandb, save_dir='../checkpoints', scheduler=scheduler)
|
||||
model.train()
|
||||
del state_dict
|
||||
|
||||
del prompt_inputs, outputs, completion_ids, per_token_logps, ref_per_token_logps
|
||||
del completions, rewards, grouped_rewards, mean_r, std_r, advantages, completion_mask
|
||||
|
||||
@ -224,7 +224,7 @@ def ppo_train_epoch(epoch, loader, iters, old_actor_model, ref_model, actor_sche
|
||||
moe_suffix = '_moe' if lm_config.use_moe else ''
|
||||
ckp = f'{args.save_dir}/{args.save_weight}_{lm_config.hidden_size}{moe_suffix}.pth'
|
||||
actor_state = actor_model.module.state_dict() if isinstance(actor_model, DistributedDataParallel) else actor_model.state_dict()
|
||||
torch.save({k: v.half() for k, v in actor_state.items()}, ckp)
|
||||
torch.save({k: v.half().cpu() for k, v in actor_state.items()}, ckp)
|
||||
|
||||
# 使用 lm_checkpoint 保存完整状态(包括 critic)
|
||||
lm_checkpoint(lm_config, weight=args.save_weight, model=actor_model, optimizer=actor_optimizer,
|
||||
@ -232,6 +232,7 @@ def ppo_train_epoch(epoch, loader, iters, old_actor_model, ref_model, actor_sche
|
||||
scheduler=actor_scheduler, critic_model=critic_model,
|
||||
critic_optimizer=critic_optimizer, critic_scheduler=critic_scheduler)
|
||||
actor_model.train()
|
||||
del actor_state
|
||||
|
||||
del enc, gen_out, responses_text, rewards, full_mask, values_seq, values, advantages
|
||||
del logits, labels, logp_tokens, final_mask, actor_logp, old_logits, old_logp, ref_logits, ref_logp
|
||||
|
||||
@ -72,10 +72,11 @@ def train_epoch(epoch, loader, iters, start_step=0, wandb=None):
|
||||
state_dict = model.module.state_dict()
|
||||
else:
|
||||
state_dict = model.state_dict()
|
||||
state_dict = {k: v.half() for k, v in state_dict.items()} # 半精度保存
|
||||
state_dict = {k: v.half().cpu() for k, v in state_dict.items()}
|
||||
torch.save(state_dict, ckp)
|
||||
lm_checkpoint(lm_config, weight=args.save_weight, model=model, optimizer=optimizer, scaler=scaler, epoch=epoch, step=step, wandb=wandb, save_dir='../checkpoints')
|
||||
model.train()
|
||||
del state_dict
|
||||
|
||||
del X, Y, loss_mask, res, loss
|
||||
|
||||
|
||||
@ -223,10 +223,11 @@ def spo_train_epoch(epoch, loader, iters, ref_model, reward_model, reward_tokeni
|
||||
moe_suffix = '_moe' if lm_config.use_moe else ''
|
||||
ckp = f'{args.save_dir}/{args.save_weight}_{lm_config.hidden_size}{moe_suffix}.pth'
|
||||
state_dict = model.module.state_dict() if isinstance(model, DistributedDataParallel) else model.state_dict()
|
||||
torch.save({k: v.half() for k, v in state_dict.items()}, ckp)
|
||||
torch.save({k: v.half().cpu() for k, v in state_dict.items()}, ckp)
|
||||
lm_checkpoint(lm_config, weight=args.save_weight, model=model, optimizer=optimizer,
|
||||
epoch=epoch, step=step, wandb=wandb, save_dir='../checkpoints', scheduler=scheduler)
|
||||
model.train()
|
||||
del state_dict
|
||||
|
||||
del prompt_inputs, outputs, completion_ids, per_token_logps, ref_per_token_logps
|
||||
del completions, rewards, advantages, completion_mask, baselines, response_masks
|
||||
|
||||
@ -1,6 +1,7 @@
|
||||
"""
|
||||
训练工具函数集合
|
||||
"""
|
||||
import gc
|
||||
import os
|
||||
import random
|
||||
import math
|
||||
@ -53,8 +54,9 @@ def lm_checkpoint(lm_config, weight='full_sft', model=None, optimizer=None, epoc
|
||||
if model is not None:
|
||||
from torch.nn.parallel import DistributedDataParallel
|
||||
state_dict = model.module.state_dict() if isinstance(model, DistributedDataParallel) else model.state_dict()
|
||||
state_dict = {k: v.half().cpu() for k, v in state_dict.items()}
|
||||
ckp_tmp = ckp_path + '.tmp'
|
||||
torch.save({k: v.half() for k, v in state_dict.items()}, ckp_tmp)
|
||||
torch.save(state_dict, ckp_tmp)
|
||||
os.replace(ckp_tmp, ckp_path)
|
||||
wandb_id = None
|
||||
if wandb:
|
||||
@ -85,6 +87,9 @@ def lm_checkpoint(lm_config, weight='full_sft', model=None, optimizer=None, epoc
|
||||
resume_tmp = resume_path + '.tmp'
|
||||
torch.save(resume_data, resume_tmp)
|
||||
os.replace(resume_tmp, resume_path)
|
||||
del state_dict, resume_data
|
||||
gc.collect()
|
||||
torch.cuda.empty_cache()
|
||||
else: # 加载模式
|
||||
if os.path.exists(resume_path):
|
||||
ckp_data = torch.load(resume_path, map_location='cpu')
|
||||
|
||||
Loading…
Reference in New Issue
Block a user