From 28cc44579a467c96f83532f677fc013b78b43517 Mon Sep 17 00:00:00 2001 From: jingyaogong Date: Thu, 23 Oct 2025 20:22:42 +0800 Subject: [PATCH] [feat] convert2llama --- scripts/convert_model.py | 9 +++++---- 1 file changed, 5 insertions(+), 4 deletions(-) diff --git a/scripts/convert_model.py b/scripts/convert_model.py index a6e22f8..b479825 100644 --- a/scripts/convert_model.py +++ b/scripts/convert_model.py @@ -12,7 +12,7 @@ warnings.filterwarnings('ignore', category=UserWarning) # MoE模型需使用此函数转换 -def convert_torch2transformers_minimind(torch_path, transformers_path, dtype=torch.bfloat16): +def convert_torch2transformers_minimind(torch_path, transformers_path, dtype=torch.float16): MiniMindConfig.register_for_auto_class() MiniMindForCausalLM.register_for_auto_class("AutoModelForCausalLM") lm_model = MiniMindForCausalLM(lm_config) @@ -29,7 +29,7 @@ def convert_torch2transformers_minimind(torch_path, transformers_path, dtype=tor # LlamaForCausalLM结构兼容第三方生态 -def convert_torch2transformers_llama(torch_path, transformers_path, dtype=torch.bfloat16): +def convert_torch2transformers_llama(torch_path, transformers_path, dtype=torch.float16): device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') state_dict = torch.load(torch_path, map_location=device) llama_config = LlamaConfig( @@ -39,9 +39,10 @@ def convert_torch2transformers_llama(torch_path, transformers_path, dtype=torch. num_hidden_layers=lm_config.num_hidden_layers, num_attention_heads=lm_config.num_attention_heads, num_key_value_heads=lm_config.num_key_value_heads, - max_position_embeddings=lm_config.max_seq_len, + max_position_embeddings=lm_config.max_position_embeddings, rms_norm_eps=lm_config.rms_norm_eps, rope_theta=lm_config.rope_theta, + tie_word_embeddings=True ) llama_model = LlamaForCausalLM(llama_config) llama_model.load_state_dict(state_dict, strict=False) @@ -68,7 +69,7 @@ if __name__ == '__main__': transformers_path = '../MiniMind2' - convert_torch2transformers_minimind(torch_path, transformers_path) + convert_torch2transformers_llama(torch_path, transformers_path) # # # convert transformers to torch model # # convert_transformers2torch(transformers_path, torch_path)