[feat] release memory

This commit is contained in:
jingyaogong
2025-11-27 19:39:49 +08:00
parent d7f4f4eab8
commit 6b86ea399a
8 changed files with 23 additions and 8 deletions
+5 -8
View File
@@ -439,7 +439,6 @@ class MiniMindForCausalLM(PreTrainedModel, GenerationMixin):
self.model = MiniMindModel(self.config)
self.lm_head = nn.Linear(self.config.hidden_size, self.config.vocab_size, bias=False)
self.model.embed_tokens.weight = self.lm_head.weight
self.OUT = CausalLMOutputWithPast()
def forward(self,
input_ids: Optional[torch.Tensor] = None,
@@ -448,7 +447,7 @@ class MiniMindForCausalLM(PreTrainedModel, GenerationMixin):
use_cache: bool = False,
logits_to_keep: Union[int, torch.Tensor] = 0,
**args):
h, past_kvs, aux_loss = self.model(
hidden_states, past_key_values, aux_loss = self.model(
input_ids=input_ids,
attention_mask=attention_mask,
past_key_values=past_key_values,
@@ -456,9 +455,7 @@ class MiniMindForCausalLM(PreTrainedModel, GenerationMixin):
**args
)
slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep
logits = self.lm_head(h[:, slice_indices, :])
self.OUT.__setitem__('last_hidden_state', h)
self.OUT.__setitem__('logits', logits)
self.OUT.__setitem__('aux_loss', aux_loss)
self.OUT.__setitem__('past_key_values', past_kvs)
return self.OUT
logits = self.lm_head(hidden_states[:, slice_indices, :])
output = CausalLMOutputWithPast(logits=logits, past_key_values=past_key_values, hidden_states=hidden_states)
output.aux_loss = aux_loss
return output