[feat] compatible tokenizer

jingyaogong 2025-12-31 10:26:46 +08:00
parent 4a5c9f5ece
commit 936d105e9b


@@ -1,5 +1,6 @@
import os
import sys
import json
__package__ = "scripts"
sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), '..')))
@@ -25,6 +26,9 @@ def convert_torch2transformers_minimind(torch_path, transformers_path, dtype=tor
lm_model.save_pretrained(transformers_path, safe_serialization=False)
tokenizer = AutoTokenizer.from_pretrained('../model/')
tokenizer.save_pretrained(transformers_path)
# Written for compatibility with transformers 5.0
config_path = os.path.join(transformers_path, "tokenizer_config.json")
with open(config_path, 'r', encoding='utf-8') as f:
    config = json.load(f)
with open(config_path, 'w', encoding='utf-8') as f:
    json.dump({**config, "tokenizer_class": "PreTrainedTokenizerFast", "extra_special_tokens": {}}, f, indent=2, ensure_ascii=False)
print(f"模型已保存为 Transformers-MiniMind 格式: {transformers_path}")
@@ -52,6 +56,9 @@ def convert_torch2transformers_llama(torch_path, transformers_path, dtype=torch.
print(f'Model parameters: {model_params / 1e6} million = {model_params / 1e9} B (Billion)')
tokenizer = AutoTokenizer.from_pretrained('../model/')
tokenizer.save_pretrained(transformers_path)
# Written for compatibility with transformers 5.0
config_path = os.path.join(transformers_path, "tokenizer_config.json")
with open(config_path, 'r', encoding='utf-8') as f:
    config = json.load(f)
with open(config_path, 'w', encoding='utf-8') as f:
    json.dump({**config, "tokenizer_class": "PreTrainedTokenizerFast", "extra_special_tokens": {}}, f, indent=2, ensure_ascii=False)
print(f"模型已保存为 Transformers-Llama 格式: {transformers_path}")