Mirror of https://github.com/jingyaogong/minimind.git, last synced 2026-04-25 08:48:16 +08:00.
[fix] transformers-5.x
This commit is contained in:
parent
5704766352
commit
1718e9a44d
@ -212,6 +212,10 @@ class MiniMindModel(nn.Module):
|
||||
past_key_values = past_key_values or [None] * len(self.layers)
|
||||
start_pos = past_key_values[0][0].shape[1] if past_key_values[0] is not None else 0
|
||||
hidden_states = self.dropout(self.embed_tokens(input_ids))
|
||||
# Recompute RoPE buffers lost during meta-device init (transformers>=5.x)
|
||||
if self.freqs_cos[0, 0] == 0:
|
||||
freqs_cos, freqs_sin = precompute_freqs_cis(dim=self.config.head_dim, end=self.config.max_position_embeddings, rope_base=self.config.rope_theta, rope_scaling=self.config.rope_scaling)
|
||||
self.freqs_cos, self.freqs_sin = freqs_cos.to(hidden_states.device), freqs_sin.to(hidden_states.device)
|
||||
position_embeddings = (self.freqs_cos[start_pos:start_pos + seq_length], self.freqs_sin[start_pos:start_pos + seq_length])
|
||||
presents = []
|
||||
for layer, past_key_value in zip(self.layers, past_key_values):
|
||||
@ -229,13 +233,15 @@ class MiniMindModel(nn.Module):
|
||||
|
||||
class MiniMindForCausalLM(PreTrainedModel, GenerationMixin):
|
||||
config_class = MiniMindConfig
|
||||
_tied_weights_keys = {"lm_head.weight": "model.embed_tokens.weight"}
|
||||
def __init__(self, config: MiniMindConfig = None):
    """Assemble the causal-LM wrapper: backbone model plus LM head.

    Falls back to a default ``MiniMindConfig`` when *config* is None,
    optionally ties the output projection to the input embeddings, and
    finishes with the standard transformers ``post_init`` pass.
    """
    resolved = MiniMindConfig() if config is None else config
    self.config = resolved
    super().__init__(resolved)

    self.model = MiniMindModel(resolved)
    self.lm_head = nn.Linear(resolved.hidden_size, resolved.vocab_size, bias=False)

    # Weight tying: the embedding matrix and the output projection
    # share one tensor when the config asks for it.
    if resolved.tie_word_embeddings:
        self.model.embed_tokens.weight = self.lm_head.weight

    self.post_init()
|
||||
|
||||
def forward(self, input_ids, attention_mask=None, past_key_values=None, use_cache=False, logits_to_keep=0, labels=None, **kwargs):
|
||||
hidden_states, past_key_values, aux_loss = self.model(input_ids, attention_mask, past_key_values, use_cache, **kwargs)
|
||||
slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep
|
||||
|
||||
Loading…
Reference in New Issue
Block a user