diff --git a/model/model_minimind.py b/model/model_minimind.py
index d3518ad..e8e7f54 100755
--- a/model/model_minimind.py
+++ b/model/model_minimind.py
@@ -27,6 +27,7 @@ class MiniMindConfig(PretrainedConfig):
         self.max_position_embeddings = kwargs.get("max_position_embeddings", 32768)
         self.rms_norm_eps = kwargs.get("rms_norm_eps", 1e-6)
         self.rope_theta = kwargs.get("rope_theta", 1e6)
+        self.tie_word_embeddings = kwargs.get("tie_word_embeddings", True)
         self.inference_rope_scaling = kwargs.get("inference_rope_scaling", False)
         self.rope_scaling = {
             "beta_fast": 32,
@@ -233,7 +234,7 @@ class MiniMindForCausalLM(PreTrainedModel, GenerationMixin):
         super().__init__(self.config)
         self.model = MiniMindModel(self.config)
         self.lm_head = nn.Linear(self.config.hidden_size, self.config.vocab_size, bias=False)
-        self.model.embed_tokens.weight = self.lm_head.weight
+        if self.config.tie_word_embeddings: self.model.embed_tokens.weight = self.lm_head.weight
 
     def forward(self, input_ids, attention_mask=None, past_key_values=None, use_cache=False, logits_to_keep=0, labels=None, **kwargs):
         hidden_states, past_key_values, aux_loss = self.model(input_ids, attention_mask, past_key_values, use_cache, **kwargs)
diff --git a/scripts/convert_model.py b/scripts/convert_model.py
index f409229..a4162a3 100644
--- a/scripts/convert_model.py
+++ b/scripts/convert_model.py
@@ -51,7 +51,7 @@ def convert_torch2transformers(torch_path, transformers_path, dtype=torch.float1
         "max_position_embeddings": lm_config.max_position_embeddings,
         "rms_norm_eps": lm_config.rms_norm_eps,
         "rope_theta": lm_config.rope_theta,
-        "tie_word_embeddings": True
+        "tie_word_embeddings": lm_config.tie_word_embeddings
     }
     if not lm_config.use_moe:
         qwen_config = Qwen3Config(
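
The change gates weight tying between model.embed_tokens and lm_head on the new config.tie_word_embeddings flag (default True, so existing checkpoints behave as before), and convert_model.py now forwards that flag instead of hard-coding True. A minimal sketch of the intended behavior, not part of the patch, assuming MiniMindForCausalLM takes the config positionally (as the hunk context suggests) and MiniMindConfig supplies defaults for the remaining fields:

    # Sketch only: exercising the new tie_word_embeddings flag.
    from model.model_minimind import MiniMindConfig, MiniMindForCausalLM

    # Default config keeps tying: both modules point at the same tensor.
    tied = MiniMindForCausalLM(MiniMindConfig())
    assert tied.model.embed_tokens.weight is tied.lm_head.weight

    # With the flag off, the embedding and output projection stay independent.
    untied = MiniMindForCausalLM(MiniMindConfig(tie_word_embeddings=False))
    assert untied.model.embed_tokens.weight is not untied.lm_head.weight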