[feat] convert2llama

jingyaogong 2025-10-23 20:22:42 +08:00
parent a82526da11
commit 28cc44579a


@@ -12,7 +12,7 @@ warnings.filterwarnings('ignore', category=UserWarning)
 # MoE models must be converted with this function
-def convert_torch2transformers_minimind(torch_path, transformers_path, dtype=torch.bfloat16):
+def convert_torch2transformers_minimind(torch_path, transformers_path, dtype=torch.float16):
     MiniMindConfig.register_for_auto_class()
     MiniMindForCausalLM.register_for_auto_class("AutoModelForCausalLM")
     lm_model = MiniMindForCausalLM(lm_config)
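Note on the dtype change: both converters now default to torch.float16 rather than torch.bfloat16, which is the safer choice for older GPUs and third-party runtimes that lack bfloat16 kernels. A minimal sketch of the cast this implies, using a hypothetical cast_state_dict helper that is not part of the script:

import torch

def cast_state_dict(state_dict, dtype=torch.float16):
    # Cast every tensor in the checkpoint to the target dtype; leave
    # non-tensor entries (e.g. metadata) untouched.
    return {k: v.to(dtype) if torch.is_tensor(v) else v
            for k, v in state_dict.items()}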
@@ -29,7 +29,7 @@ def convert_torch2transformers_minimind(torch_path, transformers_path, dtype=tor
 # The LlamaForCausalLM structure is compatible with the third-party ecosystem
-def convert_torch2transformers_llama(torch_path, transformers_path, dtype=torch.bfloat16):
+def convert_torch2transformers_llama(torch_path, transformers_path, dtype=torch.float16):
     device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
     state_dict = torch.load(torch_path, map_location=device)
     llama_config = LlamaConfig(
@@ -39,9 +39,10 @@ def convert_torch2transformers_llama(torch_path, transformers_path, dtype=torch.
         num_hidden_layers=lm_config.num_hidden_layers,
         num_attention_heads=lm_config.num_attention_heads,
         num_key_value_heads=lm_config.num_key_value_heads,
-        max_position_embeddings=lm_config.max_seq_len,
+        max_position_embeddings=lm_config.max_position_embeddings,
         rms_norm_eps=lm_config.rms_norm_eps,
         rope_theta=lm_config.rope_theta,
+        tie_word_embeddings=True
     )
     llama_model = LlamaForCausalLM(llama_config)
     llama_model.load_state_dict(state_dict, strict=False)
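The new tie_word_embeddings=True matters because MiniMind shares the output head with the input embedding; without it, LlamaForCausalLM would allocate a separate, randomly initialized lm_head that load_state_dict(strict=False) silently leaves untrained. A quick way to verify the tie after conversion (a sketch; the config values here are placeholders, not necessarily MiniMind's actual ones):

from transformers import LlamaConfig, LlamaForCausalLM

config = LlamaConfig(
    vocab_size=6400,            # placeholder values
    hidden_size=512,
    num_hidden_layers=8,
    num_attention_heads=8,
    num_key_value_heads=2,
    max_position_embeddings=8192,
    tie_word_embeddings=True,
)
model = LlamaForCausalLM(config)
# With tied embeddings, lm_head and embed_tokens share the same storage.
assert model.lm_head.weight.data_ptr() == model.model.embed_tokens.weight.data_ptr()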
@@ -68,7 +69,7 @@ if __name__ == '__main__':
     transformers_path = '../MiniMind2'
-    convert_torch2transformers_minimind(torch_path, transformers_path)
+    convert_torch2transformers_llama(torch_path, transformers_path)
     # # convert transformers to torch model
     # convert_transformers2torch(transformers_path, torch_path)
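Once convert_torch2transformers_llama has run, the export in '../MiniMind2' is a plain Llama-format checkpoint, so it should load with stock transformers and no trust_remote_code. A minimal smoke test, assuming the paths used above:

from transformers import AutoModelForCausalLM, AutoTokenizer

model = AutoModelForCausalLM.from_pretrained('../MiniMind2')
tokenizer = AutoTokenizer.from_pretrained('../MiniMind2')
print(model.config.model_type)  # expected: 'llama'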