mirror of
https://github.com/jingyaogong/minimind.git
synced 2026-06-06 00:04:50 +00:00
[feat] release memory
This commit is contained in:
@@ -439,7 +439,6 @@ class MiniMindForCausalLM(PreTrainedModel, GenerationMixin):
|
||||
self.model = MiniMindModel(self.config)
|
||||
self.lm_head = nn.Linear(self.config.hidden_size, self.config.vocab_size, bias=False)
|
||||
self.model.embed_tokens.weight = self.lm_head.weight
|
||||
self.OUT = CausalLMOutputWithPast()
|
||||
|
||||
def forward(self,
|
||||
input_ids: Optional[torch.Tensor] = None,
|
||||
@@ -448,7 +447,7 @@ class MiniMindForCausalLM(PreTrainedModel, GenerationMixin):
|
||||
use_cache: bool = False,
|
||||
logits_to_keep: Union[int, torch.Tensor] = 0,
|
||||
**args):
|
||||
h, past_kvs, aux_loss = self.model(
|
||||
hidden_states, past_key_values, aux_loss = self.model(
|
||||
input_ids=input_ids,
|
||||
attention_mask=attention_mask,
|
||||
past_key_values=past_key_values,
|
||||
@@ -456,9 +455,7 @@ class MiniMindForCausalLM(PreTrainedModel, GenerationMixin):
|
||||
**args
|
||||
)
|
||||
slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep
|
||||
logits = self.lm_head(h[:, slice_indices, :])
|
||||
self.OUT.__setitem__('last_hidden_state', h)
|
||||
self.OUT.__setitem__('logits', logits)
|
||||
self.OUT.__setitem__('aux_loss', aux_loss)
|
||||
self.OUT.__setitem__('past_key_values', past_kvs)
|
||||
return self.OUT
|
||||
logits = self.lm_head(hidden_states[:, slice_indices, :])
|
||||
output = CausalLMOutputWithPast(logits=logits, past_key_values=past_key_values, hidden_states=hidden_states)
|
||||
output.aux_loss = aux_loss
|
||||
return output
|
||||
|
||||
Reference in New Issue
Block a user