mirror of https://github.com/jingyaogong/minimind.git
[update] tie embedding
This commit is contained in:
parent 1ea113ea2c, commit 5704766352
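In short, this commit makes the input/output embedding tie configurable: MiniMindConfig gains a tie_word_embeddings flag (default True), the previously unconditional weight tie in MiniMindForCausalLM.__init__ is guarded by that flag, and convert_torch2transformers forwards the flag instead of hardcoding True.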
@@ -27,6 +27,7 @@ class MiniMindConfig(PretrainedConfig):
         self.max_position_embeddings = kwargs.get("max_position_embeddings", 32768)
         self.rms_norm_eps = kwargs.get("rms_norm_eps", 1e-6)
         self.rope_theta = kwargs.get("rope_theta", 1e6)
+        self.tie_word_embeddings = kwargs.get("tie_word_embeddings", True)
         self.inference_rope_scaling = kwargs.get("inference_rope_scaling", False)
         self.rope_scaling = {
             "beta_fast": 32,
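For orientation, a minimal sketch of the kwargs.get default pattern this hunk extends. The standalone TinyConfig class below is hypothetical (the real class is MiniMindConfig); it only illustrates how the new flag picks up its default:

# Hypothetical stand-in for MiniMindConfig, showing the kwargs.get default pattern.
from transformers import PretrainedConfig

class TinyConfig(PretrainedConfig):
    def __init__(self, **kwargs):
        # Same default as the diff: tying is enabled unless the caller opts out.
        self.tie_word_embeddings = kwargs.get("tie_word_embeddings", True)
        super().__init__(**kwargs)

print(TinyConfig().tie_word_embeddings)                          # True
print(TinyConfig(tie_word_embeddings=False).tie_word_embeddings) # False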
@@ -233,7 +234,7 @@ class MiniMindForCausalLM(PreTrainedModel, GenerationMixin):
         super().__init__(self.config)
         self.model = MiniMindModel(self.config)
         self.lm_head = nn.Linear(self.config.hidden_size, self.config.vocab_size, bias=False)
-        self.model.embed_tokens.weight = self.lm_head.weight
+        if self.config.tie_word_embeddings: self.model.embed_tokens.weight = self.lm_head.weight

     def forward(self, input_ids, attention_mask=None, past_key_values=None, use_cache=False, logits_to_keep=0, labels=None, **kwargs):
         hidden_states, past_key_values, aux_loss = self.model(input_ids, attention_mask, past_key_values, use_cache, **kwargs)
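What the guarded line does: when the flag is true, the input embedding and the output projection share a single parameter tensor, which removes one vocab_size x hidden_size matrix from the checkpoint. A minimal sketch with plain torch modules (the sizes below are illustrative, not necessarily the repo's defaults):

# Illustrative weight tie, mirroring the guarded assignment in the diff.
import torch.nn as nn

hidden_size, vocab_size = 512, 6400
embed_tokens = nn.Embedding(vocab_size, hidden_size)      # weight: (vocab, hidden)
lm_head = nn.Linear(hidden_size, vocab_size, bias=False)  # weight: (vocab, hidden)

# The tie: both modules now hold the same Parameter object.
embed_tokens.weight = lm_head.weight
assert embed_tokens.weight is lm_head.weight

# The shapes already agree ((vocab, hidden) on both sides), so no transpose is
# needed, and a gradient step through either module updates the shared tensor.

Note that before this commit the assignment ran unconditionally; the if guard restores the option of an untied head by setting tie_word_embeddings=False in the config.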
@@ -51,7 +51,7 @@ def convert_torch2transformers(torch_path, transformers_path, dtype=torch.float1
         "max_position_embeddings": lm_config.max_position_embeddings,
         "rms_norm_eps": lm_config.rms_norm_eps,
         "rope_theta": lm_config.rope_theta,
-        "tie_word_embeddings": True
+        "tie_word_embeddings": lm_config.tie_word_embeddings
     }
     if not lm_config.use_moe:
         qwen_config = Qwen3Config(
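The export side now forwards the model's own setting instead of hardcoding True, so an untied MiniMind checkpoint converts to an untied Qwen3 config. A hedged sketch of the pattern, assuming a transformers release that ships Qwen3Config (the build_export_config name and config_kwargs dict are hypothetical; the real function builds a larger dict, as the context lines show):

# Hypothetical sketch of forwarding the flag during conversion.
from transformers import Qwen3Config

def build_export_config(lm_config):
    config_kwargs = {
        "max_position_embeddings": lm_config.max_position_embeddings,
        "rms_norm_eps": lm_config.rms_norm_eps,
        "rope_theta": lm_config.rope_theta,
        # Forward the model's own setting rather than a hardcoded True,
        # so untied checkpoints stay untied after conversion.
        "tie_word_embeddings": lm_config.tie_word_embeddings,
    }
    return Qwen3Config(**config_kwargs)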