Mirror of https://github.com/jingyaogong/minimind.git (synced 2026-01-13 19:57:20 +08:00)

[feat] compatible tokenizer
commit 936d105e9b
parent 4a5c9f5ece
@@ -1,5 +1,6 @@
 import os
 import sys
+import json

 __package__ = "scripts"
 sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), '..')))
@@ -25,6 +26,9 @@ def convert_torch2transformers_minimind(torch_path, transformers_path, dtype=tor
     lm_model.save_pretrained(transformers_path, safe_serialization=False)
     tokenizer = AutoTokenizer.from_pretrained('../model/')
     tokenizer.save_pretrained(transformers_path)
+    # Written this way for compatibility with transformers 5.0
+    config_path = os.path.join(transformers_path, "tokenizer_config.json")
+    json.dump({**json.load(open(config_path, 'r', encoding='utf-8')), "tokenizer_class": "PreTrainedTokenizerFast", "extra_special_tokens": {}}, open(config_path, 'w', encoding='utf-8'), indent=2, ensure_ascii=False)
     print(f"Model saved in Transformers-MiniMind format: {transformers_path}")

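The json.dump call added above rewrites tokenizer_config.json in place, merging the existing keys with "tokenizer_class": "PreTrainedTokenizerFast" and an empty "extra_special_tokens" map. Below is a minimal, more readable sketch of the same patch; the helper name patch_tokenizer_config is hypothetical, and it assumes the converter has already written tokenizer_config.json under transformers_path. The Llama-format converter in the next hunk adds the identical three lines.

import json
import os

def patch_tokenizer_config(transformers_path):
    # Hypothetical helper restating the one-line patch from the diff above.
    config_path = os.path.join(transformers_path, "tokenizer_config.json")
    with open(config_path, "r", encoding="utf-8") as f:
        config = json.load(f)
    # Pin the fast-tokenizer class and add an explicit (empty) extra_special_tokens
    # entry so the exported tokenizer loads under transformers 5.0.
    config["tokenizer_class"] = "PreTrainedTokenizerFast"
    config["extra_special_tokens"] = {}
    with open(config_path, "w", encoding="utf-8") as f:
        json.dump(config, f, indent=2, ensure_ascii=False)
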
@@ -52,6 +56,9 @@ def convert_torch2transformers_llama(torch_path, transformers_path, dtype=torch.
     print(f'Model parameters: {model_params / 1e6} million = {model_params / 1e9} B (billion)')
     tokenizer = AutoTokenizer.from_pretrained('../model/')
     tokenizer.save_pretrained(transformers_path)
+    # Written this way for compatibility with transformers 5.0
+    config_path = os.path.join(transformers_path, "tokenizer_config.json")
+    json.dump({**json.load(open(config_path, 'r', encoding='utf-8')), "tokenizer_class": "PreTrainedTokenizerFast", "extra_special_tokens": {}}, open(config_path, 'w', encoding='utf-8'), indent=2, ensure_ascii=False)
     print(f"Model saved in Transformers-Llama format: {transformers_path}")

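As a quick sanity check (not part of this commit), the exported directory can be reloaded with AutoTokenizer to confirm the patched config resolves to PreTrainedTokenizerFast; the export path below is a placeholder.

from transformers import AutoTokenizer

transformers_path = "./MiniMind2-Transformers"  # placeholder export directory
tokenizer = AutoTokenizer.from_pretrained(transformers_path)
print(type(tokenizer).__name__)               # expected: PreTrainedTokenizerFast
print(tokenizer("hello minimind").input_ids)  # quick encode round-trip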