[fix] cuda memory #559

jingyaogong 2025-12-01 16:17:43 +08:00
parent 151fdf7e76
commit 5e1447b913
9 changed files with 22 additions and 9 deletions
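Each trainer built a half-precision copy of the state dict for torch.save while the tensors were still on the GPU, and the temporary dict stayed referenced after the save, so an extra fp16 copy of the weights could remain resident in CUDA memory long after the checkpoint was written. The fix is the same across all nine files: move the tensors to the CPU before saving, then drop the temporaries (and, in lm_checkpoint, follow up with gc.collect() and torch.cuda.empty_cache()). A minimal sketch of the pattern; the helper name save_half_checkpoint is made up for illustration:

import gc
import torch
from torch.nn.parallel import DistributedDataParallel

def save_half_checkpoint(model, ckp_path):
    # Unwrap DDP if needed, then cast to fp16 and copy to host memory so the
    # dict being saved no longer keeps CUDA tensors alive.
    raw = model.module if isinstance(model, DistributedDataParallel) else model
    state_dict = {k: v.half().cpu() for k, v in raw.state_dict().items()}
    torch.save(state_dict, ckp_path)
    # Drop the temporary copy; gc + empty_cache return cached blocks to the device.
    del state_dict
    gc.collect()
    torch.cuda.empty_cache()

Note that torch.cuda.empty_cache() only frees blocks the caching allocator no longer needs, so deleting the references first is what actually makes the fp16 copy collectable.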

View File

@@ -85,10 +85,11 @@ def train_epoch(epoch, loader, iters, tokenizer, lm_config, start_step=0, wandb=
state_dict = model.module.state_dict()
else:
state_dict = model.state_dict()
-state_dict = {k: v.half() for k, v in state_dict.items()}  # save in half precision
+state_dict = {k: v.half().cpu() for k, v in state_dict.items()}
torch.save(state_dict, ckp)
lm_checkpoint(lm_config, weight=args.save_weight, model=model, optimizer=optimizer, scaler=scaler, epoch=epoch, step=step, wandb=wandb, save_dir='../checkpoints')
model.train()
+del state_dict
del X, Y, loss_mask, res, loss

View File

@@ -123,10 +123,11 @@ def train_epoch(epoch, loader, iters, teacher_model, lm_config_student, start_st
state_dict = model.module.state_dict()
else:
state_dict = model.state_dict()
-state_dict = {k: v.half() for k, v in state_dict.items()}  # save in half precision
+state_dict = {k: v.half().cpu() for k, v in state_dict.items()}
torch.save(state_dict, ckp)
lm_checkpoint(lm_config_student, weight=args.save_weight, model=model, optimizer=optimizer, scaler=scaler, epoch=epoch, step=step, wandb=wandb, save_dir='../checkpoints')
model.train()
+del state_dict
del X, Y, loss_mask, res, student_logits, teacher_logits, ce_loss, distill_loss, loss

View File

@@ -110,10 +110,11 @@ def train_epoch(epoch, loader, iters, ref_model, lm_config, start_step=0, wandb=
state_dict = model.module.state_dict()
else:
state_dict = model.state_dict()
-state_dict = {k: v.half() for k, v in state_dict.items()}  # save in half precision
+state_dict = {k: v.half().cpu() for k, v in state_dict.items()}
torch.save(state_dict, ckp)
lm_checkpoint(lm_config, weight=args.save_weight, model=model, optimizer=optimizer, scaler=scaler, epoch=epoch, step=step, wandb=wandb, save_dir='../checkpoints')
model.train()
+del state_dict
del x_chosen, x_rejected, y_chosen, y_rejected, mask_chosen, mask_rejected, x, y, mask
del ref_outputs, ref_logits, ref_log_probs, outputs, logits, policy_log_probs, loss

View File

@@ -72,11 +72,12 @@ def train_epoch(epoch, loader, iters, start_step=0, wandb=None):
state_dict = model.module.state_dict()
else:
state_dict = model.state_dict()
-state_dict = {k: v.half() for k, v in state_dict.items()}  # save in half precision
+state_dict = {k: v.half().cpu() for k, v in state_dict.items()}
torch.save(state_dict, ckp)
lm_checkpoint(lm_config, weight=args.save_weight, model=model, optimizer=optimizer,
epoch=epoch, step=step, wandb=wandb, save_dir='../checkpoints', scaler=scaler)
model.train()
+del state_dict
del X, Y, loss_mask, res, loss

View File

@@ -175,10 +175,11 @@ def grpo_train_epoch(epoch, loader, iters, ref_model, reward_model, reward_token
moe_suffix = '_moe' if lm_config.use_moe else ''
ckp = f'{args.save_dir}/{args.save_weight}_{lm_config.hidden_size}{moe_suffix}.pth'
state_dict = model.module.state_dict() if isinstance(model, DistributedDataParallel) else model.state_dict()
-torch.save({k: v.half() for k, v in state_dict.items()}, ckp)
+torch.save({k: v.half().cpu() for k, v in state_dict.items()}, ckp)
lm_checkpoint(lm_config, weight=args.save_weight, model=model, optimizer=optimizer,
epoch=epoch, step=step, wandb=wandb, save_dir='../checkpoints', scheduler=scheduler)
model.train()
+del state_dict
del prompt_inputs, outputs, completion_ids, per_token_logps, ref_per_token_logps
del completions, rewards, grouped_rewards, mean_r, std_r, advantages, completion_mask

View File

@@ -224,7 +224,7 @@ def ppo_train_epoch(epoch, loader, iters, old_actor_model, ref_model, actor_sche
moe_suffix = '_moe' if lm_config.use_moe else ''
ckp = f'{args.save_dir}/{args.save_weight}_{lm_config.hidden_size}{moe_suffix}.pth'
actor_state = actor_model.module.state_dict() if isinstance(actor_model, DistributedDataParallel) else actor_model.state_dict()
-torch.save({k: v.half() for k, v in actor_state.items()}, ckp)
+torch.save({k: v.half().cpu() for k, v in actor_state.items()}, ckp)
# use lm_checkpoint to save the full state (including the critic)
lm_checkpoint(lm_config, weight=args.save_weight, model=actor_model, optimizer=actor_optimizer,
@@ -232,6 +232,7 @@ def ppo_train_epoch(epoch, loader, iters, old_actor_model, ref_model, actor_sche
scheduler=actor_scheduler, critic_model=critic_model,
critic_optimizer=critic_optimizer, critic_scheduler=critic_scheduler)
actor_model.train()
+del actor_state
del enc, gen_out, responses_text, rewards, full_mask, values_seq, values, advantages
del logits, labels, logp_tokens, final_mask, actor_logp, old_logits, old_logp, ref_logits, ref_logp

View File

@@ -72,10 +72,11 @@ def train_epoch(epoch, loader, iters, start_step=0, wandb=None):
state_dict = model.module.state_dict()
else:
state_dict = model.state_dict()
-state_dict = {k: v.half() for k, v in state_dict.items()}  # save in half precision
+state_dict = {k: v.half().cpu() for k, v in state_dict.items()}
torch.save(state_dict, ckp)
lm_checkpoint(lm_config, weight=args.save_weight, model=model, optimizer=optimizer, scaler=scaler, epoch=epoch, step=step, wandb=wandb, save_dir='../checkpoints')
model.train()
+del state_dict
del X, Y, loss_mask, res, loss

View File

@@ -223,10 +223,11 @@ def spo_train_epoch(epoch, loader, iters, ref_model, reward_model, reward_tokeni
moe_suffix = '_moe' if lm_config.use_moe else ''
ckp = f'{args.save_dir}/{args.save_weight}_{lm_config.hidden_size}{moe_suffix}.pth'
state_dict = model.module.state_dict() if isinstance(model, DistributedDataParallel) else model.state_dict()
-torch.save({k: v.half() for k, v in state_dict.items()}, ckp)
+torch.save({k: v.half().cpu() for k, v in state_dict.items()}, ckp)
lm_checkpoint(lm_config, weight=args.save_weight, model=model, optimizer=optimizer,
epoch=epoch, step=step, wandb=wandb, save_dir='../checkpoints', scheduler=scheduler)
model.train()
+del state_dict
del prompt_inputs, outputs, completion_ids, per_token_logps, ref_per_token_logps
del completions, rewards, advantages, completion_mask, baselines, response_masks

View File

@@ -1,6 +1,7 @@
"""
Collection of training utility functions.
"""
+import gc
import os
import random
import math
@@ -53,8 +54,9 @@ def lm_checkpoint(lm_config, weight='full_sft', model=None, optimizer=None, epoc
if model is not None:
from torch.nn.parallel import DistributedDataParallel
state_dict = model.module.state_dict() if isinstance(model, DistributedDataParallel) else model.state_dict()
+state_dict = {k: v.half().cpu() for k, v in state_dict.items()}
ckp_tmp = ckp_path + '.tmp'
-torch.save({k: v.half() for k, v in state_dict.items()}, ckp_tmp)
+torch.save(state_dict, ckp_tmp)
os.replace(ckp_tmp, ckp_path)
wandb_id = None
if wandb:
@@ -85,6 +87,9 @@ def lm_checkpoint(lm_config, weight='full_sft', model=None, optimizer=None, epoc
resume_tmp = resume_path + '.tmp'
torch.save(resume_data, resume_tmp)
os.replace(resume_tmp, resume_path)
+del state_dict, resume_data
+gc.collect()
+torch.cuda.empty_cache()
else:  # load mode
if os.path.exists(resume_path):
ckp_data = torch.load(resume_path, map_location='cpu')
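The load branch complements the save-side change: map_location='cpu' materializes the restored tensors in host memory rather than on the GPU. To confirm the fix on a live run, one rough check (the snippet below is hypothetical and not part of this commit) is to compare allocator stats around a checkpoint call:

import torch

before = torch.cuda.memory_allocated()
lm_checkpoint(lm_config, weight=args.save_weight, model=model, optimizer=optimizer,
              epoch=epoch, step=step, save_dir='../checkpoints')
after = torch.cuda.memory_allocated()
print(f"tensor memory held across save: {before / 2**20:.1f} MiB -> {after / 2**20:.1f} MiB")

With the .cpu() move and the del/gc/empty_cache cleanup in place, the two numbers should be roughly equal; without them, roughly the size of the fp16 weights could stay allocated after each save.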