[update] tie embedding

This commit is contained in:
jingyaogong 2026-04-19 21:57:28 +08:00
parent 1ea113ea2c
commit 5704766352
2 changed files with 3 additions and 2 deletions

View File

@ -27,6 +27,7 @@ class MiniMindConfig(PretrainedConfig):
self.max_position_embeddings = kwargs.get("max_position_embeddings", 32768)
self.rms_norm_eps = kwargs.get("rms_norm_eps", 1e-6)
self.rope_theta = kwargs.get("rope_theta", 1e6)
self.tie_word_embeddings = kwargs.get("tie_word_embeddings", True)
self.inference_rope_scaling = kwargs.get("inference_rope_scaling", False)
self.rope_scaling = {
"beta_fast": 32,
@ -233,7 +234,7 @@ class MiniMindForCausalLM(PreTrainedModel, GenerationMixin):
super().__init__(self.config)
self.model = MiniMindModel(self.config)
self.lm_head = nn.Linear(self.config.hidden_size, self.config.vocab_size, bias=False)
self.model.embed_tokens.weight = self.lm_head.weight
if self.config.tie_word_embeddings: self.model.embed_tokens.weight = self.lm_head.weight
def forward(self, input_ids, attention_mask=None, past_key_values=None, use_cache=False, logits_to_keep=0, labels=None, **kwargs):
hidden_states, past_key_values, aux_loss = self.model(input_ids, attention_mask, past_key_values, use_cache, **kwargs)

View File

@ -51,7 +51,7 @@ def convert_torch2transformers(torch_path, transformers_path, dtype=torch.float1
"max_position_embeddings": lm_config.max_position_embeddings,
"rms_norm_eps": lm_config.rms_norm_eps,
"rope_theta": lm_config.rope_theta,
"tie_word_embeddings": True
"tie_word_embeddings": lm_config.tie_word_embeddings
}
if not lm_config.use_moe:
qwen_config = Qwen3Config(