diff --git a/model/model_minimind.py b/model/model_minimind.py
index e6b6096..259f0af 100755
--- a/model/model_minimind.py
+++ b/model/model_minimind.py
@@ -439,7 +439,6 @@ class MiniMindForCausalLM(PreTrainedModel, GenerationMixin):
         self.model = MiniMindModel(self.config)
         self.lm_head = nn.Linear(self.config.hidden_size, self.config.vocab_size, bias=False)
         self.model.embed_tokens.weight = self.lm_head.weight
-        self.OUT = CausalLMOutputWithPast()
 
     def forward(self,
                 input_ids: Optional[torch.Tensor] = None,
@@ -448,7 +447,7 @@ class MiniMindForCausalLM(PreTrainedModel, GenerationMixin):
                 use_cache: bool = False,
                 logits_to_keep: Union[int, torch.Tensor] = 0,
                 **args):
-        h, past_kvs, aux_loss = self.model(
+        hidden_states, past_key_values, aux_loss = self.model(
             input_ids=input_ids,
             attention_mask=attention_mask,
             past_key_values=past_key_values,
@@ -456,9 +455,7 @@ class MiniMindForCausalLM(PreTrainedModel, GenerationMixin):
             **args
         )
         slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep
-        logits = self.lm_head(h[:, slice_indices, :])
-        self.OUT.__setitem__('last_hidden_state', h)
-        self.OUT.__setitem__('logits', logits)
-        self.OUT.__setitem__('aux_loss', aux_loss)
-        self.OUT.__setitem__('past_key_values', past_kvs)
-        return self.OUT
+        logits = self.lm_head(hidden_states[:, slice_indices, :])
+        output = CausalLMOutputWithPast(logits=logits, past_key_values=past_key_values, hidden_states=hidden_states)
+        output.aux_loss = aux_loss
+        return output
diff --git a/trainer/train_distill_reason.py b/trainer/train_distill_reason.py
index ee1ceba..be04798 100644
--- a/trainer/train_distill_reason.py
+++ b/trainer/train_distill_reason.py
@@ -90,6 +90,8 @@ def train_epoch(epoch, loader, iters, tokenizer, lm_config, start_step=0, wandb=
             lm_checkpoint(lm_config, weight=args.save_weight, model=model, optimizer=optimizer, scaler=scaler,
                           epoch=epoch, step=step, wandb=wandb, save_dir='../checkpoints')
             model.train()
+        del X, Y, loss_mask, res, loss
+
 
 if __name__ == "__main__":
     parser = argparse.ArgumentParser(description="MiniMind Reasoning Distillation")
diff --git a/trainer/train_distillation.py b/trainer/train_distillation.py
index 1105e8d..8f5e3f9 100644
--- a/trainer/train_distillation.py
+++ b/trainer/train_distillation.py
@@ -128,6 +128,8 @@ def train_epoch(epoch, loader, iters, teacher_model, lm_config_student, start_st
             lm_checkpoint(lm_config_student, weight=args.save_weight, model=model, optimizer=optimizer, scaler=scaler,
                           epoch=epoch, step=step, wandb=wandb, save_dir='../checkpoints')
             model.train()
+        del X, Y, loss_mask, res, student_logits, teacher_logits, ce_loss, distill_loss, loss
+
 
 if __name__ == "__main__":
     parser = argparse.ArgumentParser(description="MiniMind Knowledge Distillation")
diff --git a/trainer/train_dpo.py b/trainer/train_dpo.py
index e6bfc40..afa51f4 100644
--- a/trainer/train_dpo.py
+++ b/trainer/train_dpo.py
@@ -115,6 +115,9 @@ def train_epoch(epoch, loader, iters, ref_model, lm_config, start_step=0, wandb=
             lm_checkpoint(lm_config, weight=args.save_weight, model=model, optimizer=optimizer, scaler=scaler,
                           epoch=epoch, step=step, wandb=wandb, save_dir='../checkpoints')
             model.train()
+        del x_chosen, x_rejected, y_chosen, y_rejected, mask_chosen, mask_rejected, x, y, mask
+        del ref_outputs, ref_logits, ref_log_probs, outputs, logits, policy_log_probs, loss
+
 
 if __name__ == "__main__":
     parser = argparse.ArgumentParser(description="MiniMind DPO (Direct Preference Optimization)")
diff --git a/trainer/train_full_sft.py b/trainer/train_full_sft.py
index 316ed74..4d0a181 100644
--- a/trainer/train_full_sft.py
+++ b/trainer/train_full_sft.py
@@ -78,6 +78,8 @@ def train_epoch(epoch, loader, iters, start_step=0, wandb=None):
                           epoch=epoch, step=step, wandb=wandb, save_dir='../checkpoints', scaler=scaler)
             model.train()
+        del X, Y, loss_mask, res, loss
+
 
 if __name__ == "__main__":
     parser = argparse.ArgumentParser(description="MiniMind Full SFT")
diff --git a/trainer/train_lora.py b/trainer/train_lora.py
index c98537d..4e2d30f 100644
--- a/trainer/train_lora.py
+++ b/trainer/train_lora.py
@@ -73,6 +73,8 @@ def train_epoch(epoch, loader, iters, lora_params, start_step=0, wandb=None):
             lm_checkpoint(lm_config, weight=args.lora_name, model=model, optimizer=optimizer, scaler=scaler,
                           epoch=epoch, step=step, wandb=wandb, save_dir='../checkpoints')
             model.train()
+        del X, Y, loss_mask, res, loss
+
 
 if __name__ == "__main__":
     parser = argparse.ArgumentParser(description="MiniMind LoRA Fine-tuning")
diff --git a/trainer/train_ppo.py b/trainer/train_ppo.py
index 0780fd3..cb6d7ff 100644
--- a/trainer/train_ppo.py
+++ b/trainer/train_ppo.py
@@ -233,6 +233,11 @@ def ppo_train_epoch(epoch, loader, iters, old_actor_model, ref_model, actor_sche
                           critic_optimizer=critic_optimizer, critic_scheduler=critic_scheduler)
             actor_model.train()
+        del enc, gen_out, responses_text, rewards, full_mask, values_seq, values, advantages
+        del logits, labels, logp_tokens, final_mask, actor_logp, old_logits, old_logp, ref_logits, ref_logp
+        del kl, kl_ref, ratio, surr1, surr2, policy_loss, value_loss, loss
+        torch.cuda.empty_cache()
+
 
 if __name__ == "__main__":
     parser = argparse.ArgumentParser(description="MiniMind PPO (Proximal Policy Optimization)")
diff --git a/trainer/train_pretrain.py b/trainer/train_pretrain.py
index c5acd3e..a926288 100644
--- a/trainer/train_pretrain.py
+++ b/trainer/train_pretrain.py
@@ -77,6 +77,8 @@ def train_epoch(epoch, loader, iters, start_step=0, wandb=None):
             lm_checkpoint(lm_config, weight=args.save_weight, model=model, optimizer=optimizer, scaler=scaler,
                           epoch=epoch, step=step, wandb=wandb, save_dir='../checkpoints')
             model.train()
+        del X, Y, loss_mask, res, loss
+
 
 if __name__ == "__main__":
     parser = argparse.ArgumentParser(description="MiniMind Pretraining")