mirror of
https://github.com/jingyaogong/minimind.git
synced 2026-06-06 00:04:50 +00:00
[update] tie embedding
This commit is contained in:
@@ -27,6 +27,7 @@ class MiniMindConfig(PretrainedConfig):
|
|||||||
self.max_position_embeddings = kwargs.get("max_position_embeddings", 32768)
|
self.max_position_embeddings = kwargs.get("max_position_embeddings", 32768)
|
||||||
self.rms_norm_eps = kwargs.get("rms_norm_eps", 1e-6)
|
self.rms_norm_eps = kwargs.get("rms_norm_eps", 1e-6)
|
||||||
self.rope_theta = kwargs.get("rope_theta", 1e6)
|
self.rope_theta = kwargs.get("rope_theta", 1e6)
|
||||||
|
self.tie_word_embeddings = kwargs.get("tie_word_embeddings", True)
|
||||||
self.inference_rope_scaling = kwargs.get("inference_rope_scaling", False)
|
self.inference_rope_scaling = kwargs.get("inference_rope_scaling", False)
|
||||||
self.rope_scaling = {
|
self.rope_scaling = {
|
||||||
"beta_fast": 32,
|
"beta_fast": 32,
|
||||||
@@ -233,7 +234,7 @@ class MiniMindForCausalLM(PreTrainedModel, GenerationMixin):
|
|||||||
super().__init__(self.config)
|
super().__init__(self.config)
|
||||||
self.model = MiniMindModel(self.config)
|
self.model = MiniMindModel(self.config)
|
||||||
self.lm_head = nn.Linear(self.config.hidden_size, self.config.vocab_size, bias=False)
|
self.lm_head = nn.Linear(self.config.hidden_size, self.config.vocab_size, bias=False)
|
||||||
self.model.embed_tokens.weight = self.lm_head.weight
|
if self.config.tie_word_embeddings: self.model.embed_tokens.weight = self.lm_head.weight
|
||||||
|
|
||||||
def forward(self, input_ids, attention_mask=None, past_key_values=None, use_cache=False, logits_to_keep=0, labels=None, **kwargs):
|
def forward(self, input_ids, attention_mask=None, past_key_values=None, use_cache=False, logits_to_keep=0, labels=None, **kwargs):
|
||||||
hidden_states, past_key_values, aux_loss = self.model(input_ids, attention_mask, past_key_values, use_cache, **kwargs)
|
hidden_states, past_key_values, aux_loss = self.model(input_ids, attention_mask, past_key_values, use_cache, **kwargs)
|
||||||
|
|||||||
@@ -51,7 +51,7 @@ def convert_torch2transformers(torch_path, transformers_path, dtype=torch.float1
|
|||||||
"max_position_embeddings": lm_config.max_position_embeddings,
|
"max_position_embeddings": lm_config.max_position_embeddings,
|
||||||
"rms_norm_eps": lm_config.rms_norm_eps,
|
"rms_norm_eps": lm_config.rms_norm_eps,
|
||||||
"rope_theta": lm_config.rope_theta,
|
"rope_theta": lm_config.rope_theta,
|
||||||
"tie_word_embeddings": True
|
"tie_word_embeddings": lm_config.tie_word_embeddings
|
||||||
}
|
}
|
||||||
if not lm_config.use_moe:
|
if not lm_config.use_moe:
|
||||||
qwen_config = Qwen3Config(
|
qwen_config = Qwen3Config(
|
||||||
|
|||||||
Reference in New Issue
Block a user