mirror of
https://github.com/jingyaogong/minimind.git
synced 2026-04-25 08:48:16 +08:00
[feat] convert2llama
This commit is contained in:
parent
a82526da11
commit
28cc44579a
@ -12,7 +12,7 @@ warnings.filterwarnings('ignore', category=UserWarning)
|
||||
|
||||
|
||||
# MoE models must be converted with this function
|
||||
def convert_torch2transformers_minimind(torch_path, transformers_path, dtype=torch.bfloat16):
|
||||
def convert_torch2transformers_minimind(torch_path, transformers_path, dtype=torch.float16):
|
||||
MiniMindConfig.register_for_auto_class()
|
||||
MiniMindForCausalLM.register_for_auto_class("AutoModelForCausalLM")
|
||||
lm_model = MiniMindForCausalLM(lm_config)
|
||||
@ -29,7 +29,7 @@ def convert_torch2transformers_minimind(torch_path, transformers_path, dtype=tor
|
||||
|
||||
|
||||
# The LlamaForCausalLM structure is compatible with the third-party ecosystem
|
||||
def convert_torch2transformers_llama(torch_path, transformers_path, dtype=torch.bfloat16):
|
||||
def convert_torch2transformers_llama(torch_path, transformers_path, dtype=torch.float16):
|
||||
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
|
||||
state_dict = torch.load(torch_path, map_location=device)
|
||||
llama_config = LlamaConfig(
|
||||
@ -39,9 +39,10 @@ def convert_torch2transformers_llama(torch_path, transformers_path, dtype=torch.
|
||||
num_hidden_layers=lm_config.num_hidden_layers,
|
||||
num_attention_heads=lm_config.num_attention_heads,
|
||||
num_key_value_heads=lm_config.num_key_value_heads,
|
||||
max_position_embeddings=lm_config.max_seq_len,
|
||||
max_position_embeddings=lm_config.max_position_embeddings,
|
||||
rms_norm_eps=lm_config.rms_norm_eps,
|
||||
rope_theta=lm_config.rope_theta,
|
||||
tie_word_embeddings=True
|
||||
)
|
||||
llama_model = LlamaForCausalLM(llama_config)
|
||||
llama_model.load_state_dict(state_dict, strict=False)
|
||||
@ -68,7 +69,7 @@ if __name__ == '__main__':
|
||||
|
||||
transformers_path = '../MiniMind2'
|
||||
|
||||
convert_torch2transformers_minimind(torch_path, transformers_path)
|
||||
convert_torch2transformers_llama(torch_path, transformers_path)
|
||||
|
||||
# # # convert transformers to torch model
|
||||
# # convert_transformers2torch(transformers_path, torch_path)
|
||||
|
||||
Loading…
Reference in New Issue
Block a user