[feat] release memory

jingyaogong 2025-11-27 19:39:49 +08:00
parent d7f4f4eab8
commit 6b86ea399a
8 changed files with 23 additions and 8 deletions

View File

@@ -439,7 +439,6 @@ class MiniMindForCausalLM(PreTrainedModel, GenerationMixin):
         self.model = MiniMindModel(self.config)
         self.lm_head = nn.Linear(self.config.hidden_size, self.config.vocab_size, bias=False)
         self.model.embed_tokens.weight = self.lm_head.weight
-        self.OUT = CausalLMOutputWithPast()

     def forward(self,
                 input_ids: Optional[torch.Tensor] = None,
@@ -448,7 +447,7 @@ class MiniMindForCausalLM(PreTrainedModel, GenerationMixin):
                 use_cache: bool = False,
                 logits_to_keep: Union[int, torch.Tensor] = 0,
                 **args):
-        h, past_kvs, aux_loss = self.model(
+        hidden_states, past_key_values, aux_loss = self.model(
             input_ids=input_ids,
             attention_mask=attention_mask,
             past_key_values=past_key_values,
@@ -456,9 +455,7 @@ class MiniMindForCausalLM(PreTrainedModel, GenerationMixin):
             **args
         )
         slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep
-        logits = self.lm_head(h[:, slice_indices, :])
-        self.OUT.__setitem__('last_hidden_state', h)
-        self.OUT.__setitem__('logits', logits)
-        self.OUT.__setitem__('aux_loss', aux_loss)
-        self.OUT.__setitem__('past_key_values', past_kvs)
-        return self.OUT
+        logits = self.lm_head(hidden_states[:, slice_indices, :])
+        output = CausalLMOutputWithPast(logits=logits, past_key_values=past_key_values, hidden_states=hidden_states)
+        output.aux_loss = aux_loss
+        return output
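
Note on the change above: previously the module kept a single CausalLMOutputWithPast in self.OUT and overwrote its fields on every forward pass, so the logits, hidden states, and KV cache of the last batch stayed referenced (and their GPU memory held) until the next call. Building a fresh output object per call lets those tensors be freed as soon as the caller drops them. A minimal sketch of the two patterns, using a toy module and a plain dict instead of the real MiniMind classes:

import torch
import torch.nn as nn

class LeakyLM(nn.Module):
    """Anti-pattern: a persistent output attribute pins the previous batch's tensors."""
    def __init__(self):
        super().__init__()
        self.lm_head = nn.Linear(16, 100)
        self.OUT = {}  # lives as long as the module; holds the last logits indefinitely

    def forward(self, h):
        self.OUT['logits'] = self.lm_head(h)  # old logits freed only when overwritten here
        return self.OUT

class CleanLM(nn.Module):
    """Fixed pattern: a fresh output per call; tensors die with the caller's last reference."""
    def __init__(self):
        super().__init__()
        self.lm_head = nn.Linear(16, 100)

    def forward(self, h):
        return {'logits': self.lm_head(h)}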

View File

@@ -90,6 +90,8 @@ def train_epoch(epoch, loader, iters, tokenizer, lm_config, start_step=0, wandb=
             lm_checkpoint(lm_config, weight=args.save_weight, model=model, optimizer=optimizer, scaler=scaler, epoch=epoch, step=step, wandb=wandb, save_dir='../checkpoints')
             model.train()

+        del X, Y, loss_mask, res, loss
+

 if __name__ == "__main__":
     parser = argparse.ArgumentParser(description="MiniMind Reasoning Distillation")
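
This trainer change, repeated with matching variable lists in every training script below, frees the last iteration's tensors explicitly: without the del, names like res and loss keep the previous step's output alive until they are rebound midway through the next step, raising peak memory exactly when the next forward pass allocates. A minimal sketch of the loop shape, assuming generic model, loader, and criterion objects rather than the actual trainer code:

def train_epoch(model, loader, optimizer, criterion):
    model.train()
    for X, Y in loader:
        optimizer.zero_grad()
        res = model(X)            # activations stay reachable through `res`
        loss = criterion(res, Y)
        loss.backward()
        optimizer.step()
        # Drop the last references now, so the allocator can reuse this
        # memory before the next iteration's forward pass, instead of
        # waiting for `res`/`loss` to be rebound mid-step.
        del X, Y, res, loss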

View File

@@ -128,6 +128,8 @@ def train_epoch(epoch, loader, iters, teacher_model, lm_config_student, start_st
             lm_checkpoint(lm_config_student, weight=args.save_weight, model=model, optimizer=optimizer, scaler=scaler, epoch=epoch, step=step, wandb=wandb, save_dir='../checkpoints')
             model.train()

+        del X, Y, loss_mask, res, student_logits, teacher_logits, ce_loss, distill_loss, loss
+

 if __name__ == "__main__":
     parser = argparse.ArgumentParser(description="MiniMind Knowledge Distillation")

View File

@@ -115,6 +115,9 @@ def train_epoch(epoch, loader, iters, ref_model, lm_config, start_step=0, wandb=
             lm_checkpoint(lm_config, weight=args.save_weight, model=model, optimizer=optimizer, scaler=scaler, epoch=epoch, step=step, wandb=wandb, save_dir='../checkpoints')
             model.train()

+        del x_chosen, x_rejected, y_chosen, y_rejected, mask_chosen, mask_rejected, x, y, mask
+        del ref_outputs, ref_logits, ref_log_probs, outputs, logits, policy_log_probs, loss
+

 if __name__ == "__main__":
     parser = argparse.ArgumentParser(description="MiniMind DPO (Direct Preference Optimization)")

View File

@@ -78,6 +78,8 @@ def train_epoch(epoch, loader, iters, start_step=0, wandb=None):
                           epoch=epoch, step=step, wandb=wandb, save_dir='../checkpoints', scaler=scaler)
             model.train()

+        del X, Y, loss_mask, res, loss
+

 if __name__ == "__main__":
     parser = argparse.ArgumentParser(description="MiniMind Full SFT")

View File

@@ -73,6 +73,8 @@ def train_epoch(epoch, loader, iters, lora_params, start_step=0, wandb=None):
             lm_checkpoint(lm_config, weight=args.lora_name, model=model, optimizer=optimizer, scaler=scaler, epoch=epoch, step=step, wandb=wandb, save_dir='../checkpoints')
             model.train()

+        del X, Y, loss_mask, res, loss
+

 if __name__ == "__main__":
     parser = argparse.ArgumentParser(description="MiniMind LoRA Fine-tuning")

View File

@@ -233,6 +233,11 @@ def ppo_train_epoch(epoch, loader, iters, old_actor_model, ref_model, actor_sche
                           critic_optimizer=critic_optimizer, critic_scheduler=critic_scheduler)
             actor_model.train()

+        del enc, gen_out, responses_text, rewards, full_mask, values_seq, values, advantages
+        del logits, labels, logp_tokens, final_mask, actor_logp, old_logits, old_logp, ref_logits, ref_logp
+        del kl, kl_ref, ratio, surr1, surr2, policy_loss, value_loss, loss
+        torch.cuda.empty_cache()
+

 if __name__ == "__main__":
     parser = argparse.ArgumentParser(description="MiniMind PPO (Proximal Policy Optimization)")
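
The PPO trainer goes one step further: besides deleting the rollout and loss tensors spanning the actor, old actor, reference, and critic models, it calls torch.cuda.empty_cache(), which hands blocks that PyTorch's caching allocator no longer uses back to the CUDA driver. It never frees live tensors, and calling it every step trades some speed for a smaller reserved footprint. A small illustration of the effect (illustrative sizes only):

import torch

if torch.cuda.is_available():
    x = torch.randn(4096, 4096, device='cuda')  # ~64 MB live tensor
    print(torch.cuda.memory_allocated())        # bytes held by live tensors
    print(torch.cuda.memory_reserved())         # bytes cached by the allocator
    del x                                       # tensor freed, blocks stay cached
    torch.cuda.empty_cache()                    # cached blocks returned to the driver
    print(torch.cuda.memory_reserved())         # drops back toward zero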

View File

@@ -77,6 +77,8 @@ def train_epoch(epoch, loader, iters, start_step=0, wandb=None):
             lm_checkpoint(lm_config, weight=args.save_weight, model=model, optimizer=optimizer, scaler=scaler, epoch=epoch, step=step, wandb=wandb, save_dir='../checkpoints')
             model.train()

+        del X, Y, loss_mask, res, loss
+

 if __name__ == "__main__":
     parser = argparse.ArgumentParser(description="MiniMind Pretraining")