diff --git a/model/model_minimind.py b/model/model_minimind.py
index e8e7f54..35bb66f 100755
--- a/model/model_minimind.py
+++ b/model/model_minimind.py
@@ -212,6 +212,10 @@ class MiniMindModel(nn.Module):
         past_key_values = past_key_values or [None] * len(self.layers)
         start_pos = past_key_values[0][0].shape[1] if past_key_values[0] is not None else 0
         hidden_states = self.dropout(self.embed_tokens(input_ids))
+        # Recompute RoPE buffers lost during meta-device init (transformers>=5.x)
+        if self.freqs_cos[0, 0] == 0:
+            freqs_cos, freqs_sin = precompute_freqs_cis(dim=self.config.head_dim, end=self.config.max_position_embeddings, rope_base=self.config.rope_theta, rope_scaling=self.config.rope_scaling)
+            self.freqs_cos, self.freqs_sin = freqs_cos.to(hidden_states.device), freqs_sin.to(hidden_states.device)
         position_embeddings = (self.freqs_cos[start_pos:start_pos + seq_length], self.freqs_sin[start_pos:start_pos + seq_length])
         presents = []
         for layer, past_key_value in zip(self.layers, past_key_values):
@@ -229,13 +233,15 @@ class MiniMindModel(nn.Module):
 class MiniMindForCausalLM(PreTrainedModel, GenerationMixin):
     config_class = MiniMindConfig
+    _tied_weights_keys = {"lm_head.weight": "model.embed_tokens.weight"}
 
     def __init__(self, config: MiniMindConfig = None):
         self.config = config or MiniMindConfig()
         super().__init__(self.config)
         self.model = MiniMindModel(self.config)
         self.lm_head = nn.Linear(self.config.hidden_size, self.config.vocab_size, bias=False)
         if self.config.tie_word_embeddings: self.model.embed_tokens.weight = self.lm_head.weight
-
+        self.post_init()
+
     def forward(self, input_ids, attention_mask=None, past_key_values=None, use_cache=False, logits_to_keep=0, labels=None, **kwargs):
         hidden_states, past_key_values, aux_loss = self.model(input_ids, attention_mask, past_key_values, use_cache, **kwargs)
         slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep
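
Below is a minimal verification sketch, not part of the patch, for exercising the two changes above: it assumes a local MiniMind checkpoint directory (the path ./MiniMind2 is a placeholder) loaded via transformers with trust_remote_code, runs one forward pass to trigger the freqs_cos/freqs_sin recompute, and checks that the position-0 cosine is non-zero (cos(0) should be 1 once the buffers are valid) and that the tied lm_head/embedding weights share storage.

# Verification sketch (assumption: a local MiniMind checkpoint at "./MiniMind2").
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("./MiniMind2")
model = AutoModelForCausalLM.from_pretrained("./MiniMind2", trust_remote_code=True)

inputs = tokenizer("hello", return_tensors="pt")
with torch.no_grad():
    model(**inputs)  # first forward recomputes the RoPE buffers if meta-device init zeroed them

# cos(0) == 1, so position 0 of freqs_cos must be non-zero once the buffers are valid
assert model.model.freqs_cos[0, 0].item() != 0.0

# with tie_word_embeddings enabled, lm_head and embed_tokens should share the same storage
if model.config.tie_word_embeddings:
    assert model.lm_head.weight.data_ptr() == model.model.embed_tokens.weight.data_ptr()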