# Mirror of https://github.com/jingyaogong/minimind.git
# Synced 2026-01-13 19:57:20 +08:00 (148 lines, 8.6 KiB, Python).
import random
|
||
import json
|
||
from tokenizers import (
|
||
decoders,
|
||
models,
|
||
pre_tokenizers,
|
||
trainers,
|
||
Tokenizer,
|
||
)
|
||
import os
|
||
|
||
# Seed Python's global RNG with a fixed value at import time.
# NOTE(review): nothing else visible in this file draws from `random`;
# presumably kept for reproducibility of other code — confirm before removing.
random.seed(42)
|
||
|
||
|
||
def train_tokenizer(data_path='../dataset/pretrain_hq.jsonl', tokenizer_dir='../model/', vocab_size=6400):
    """Train a byte-level BPE tokenizer on a JSONL corpus and save it for `transformers`.

    Reads the ``text`` field of every record in ``data_path`` and trains a BPE model
    with three special tokens. It then writes these files into ``tokenizer_dir``:
    ``tokenizer.json``, the BPE model's own files (via ``tokenizer.model.save``), and a
    hand-written ``tokenizer_config.json`` that includes the chat template.

    Args:
        data_path: UTF-8 JSONL file. Each non-blank line must be a JSON object with
            a ``text`` key.
        tokenizer_dir: Output directory; it is created if missing.
        vocab_size: Target vocabulary size passed to the BPE trainer.

    Raises:
        json.JSONDecodeError: A non-blank line of ``data_path`` is not valid JSON.
        KeyError: A record has no ``text`` key.
        RuntimeError: The special tokens did not receive ids 0, 1, 2 in list order.
    """

    def read_texts_from_jsonl(file_path):
        """Yield the ``text`` field of each JSON record in ``file_path``, lazily."""
        with open(file_path, 'r', encoding='utf-8') as f:
            for line in f:
                # Skip blank lines (e.g. an extra empty line at the end of the
                # file) instead of letting json.loads('') raise.
                if not line.strip():
                    continue
                yield json.loads(line)['text']

    # Initialize the tokenizer: BPE model with byte-level pre-tokenization.
    tokenizer = Tokenizer(models.BPE())
    tokenizer.pre_tokenizer = pre_tokenizers.ByteLevel(add_prefix_space=False)

    # Special tokens. They are expected to receive ids 0, 1, 2 in this order
    # (checked after training), and the config below is keyed by those ids.
    special_tokens = ["<|endoftext|>", "<|im_start|>", "<|im_end|>"]

    trainer = trainers.BpeTrainer(
        vocab_size=vocab_size,
        special_tokens=special_tokens,  # ensure all three tokens are in the vocab
        show_progress=True,
        # Seed the trainer with the ByteLevel pre-tokenizer's alphabet.
        initial_alphabet=pre_tokenizers.ByteLevel.alphabet(),
    )

    # Train straight from the lazy generator.
    tokenizer.train_from_iterator(read_texts_from_jsonl(data_path), trainer=trainer)

    # Decoder that matches the ByteLevel pre-tokenizer.
    tokenizer.decoder = decoders.ByteLevel()

    # Verify special-token ids with explicit checks rather than `assert`, which
    # `python -O` would strip. The config written below relies on these ids.
    for expected_id, token in enumerate(special_tokens):
        actual_id = tokenizer.token_to_id(token)
        if actual_id != expected_id:
            raise RuntimeError(
                f"special token {token!r} has id {actual_id}, expected {expected_id}"
            )

    # Save the tokenizer. The raw model files go into the same directory as
    # tokenizer.json.
    os.makedirs(tokenizer_dir, exist_ok=True)
    tokenizer.save(os.path.join(tokenizer_dir, "tokenizer.json"))
    tokenizer.model.save(tokenizer_dir)

    # Hand-written transformers config.
    config = {
        "add_bos_token": False,
        "add_eos_token": False,
        "add_prefix_space": False,
        # One entry per special token, keyed by its id as a string. Generated from
        # `special_tokens` so it cannot drift out of sync with the list above.
        "added_tokens_decoder": {
            str(token_id): {
                "content": token,
                "lstrip": False,
                "normalized": False,
                "rstrip": False,
                "single_word": False,
                "special": True,
            }
            for token_id, token in enumerate(special_tokens)
        },
        "additional_special_tokens": [],
        "bos_token": "<|im_start|>",
        "clean_up_tokenization_spaces": False,
        "eos_token": "<|im_end|>",
        "legacy": True,
        "model_max_length": 32768,
        "pad_token": "<|endoftext|>",
        "sp_model_kwargs": {},
        "spaces_between_special_tokens": False,
        "tokenizer_class": "PreTrainedTokenizerFast",
        "unk_token": "<|endoftext|>",
        "chat_template": "{%- if tools %}\n {{- '<|im_start|>system\\n' }}\n {%- if messages[0].role == 'system' %}\n {{- messages[0].content + '\\n\\n' }}\n {%- endif %}\n {{- \"# Tools\\n\\nYou may call one or more functions to assist with the user query.\\n\\nYou are provided with function signatures within <tools></tools> XML tags:\\n<tools>\" }}\n {%- for tool in tools %}\n {{- \"\\n\" }}\n {{- tool | tojson }}\n {%- endfor %}\n {{- \"\\n</tools>\\n\\nFor each function call, return a json object with function name and arguments within <tool_call></tool_call> XML tags:\\n<tool_call>\\n{\\\"name\\\": <function-name>, \\\"arguments\\\": <args-json-object>}\\n</tool_call><|im_end|>\\n\" }}\n{%- else %}\n {%- if messages[0]['role'] == 'system' -%}\n {{- '<|im_start|>system\\n' + messages[0]['content'] + '<|im_end|>\\n' }}\n {%- else -%}\n {{- '<|im_start|>system\\nYou are a helpful assistant<|im_end|>\\n' }}\n {%- endif %}\n{%- endif %}\n{%- set ns = namespace(multi_step_tool=true, last_query_index=messages|length - 1) %}\n{%- for message in messages[::-1] %}\n {%- set index = (messages|length - 1) - loop.index0 %}\n {%- if ns.multi_step_tool and message.role == \"user\" and message.content is string and not(message.content.startswith('<tool_response>') and message.content.endswith('</tool_response>')) %}\n {%- set ns.multi_step_tool = false %}\n {%- set ns.last_query_index = index %}\n {%- endif %}\n{%- endfor %}\n{%- for message in messages %}\n {%- if message.content is string %}\n {%- set content = message.content %}\n {%- else %}\n {%- set content = '' %}\n {%- endif %}\n {%- if (message.role == \"user\") or (message.role == \"system\" and not loop.first) %}\n {{- '<|im_start|>' + message.role + '\\n' + content + '<|im_end|>' + '\\n' }}\n {%- elif message.role == \"assistant\" %}\n {%- set reasoning_content = '' %}\n {%- if message.reasoning_content is string %}\n {%- set reasoning_content = message.reasoning_content %}\n {%- else %}\n {%- if '</think>' in content %}\n {%- set reasoning_content = content.split('</think>')[0].rstrip('\\n').split('<think>')[-1].lstrip('\\n') %}\n {%- set content = content.split('</think>')[-1].lstrip('\\n') %}\n {%- endif %}\n {%- endif %}\n {%- if loop.index0 > ns.last_query_index %}\n {%- if loop.last or (not loop.last and reasoning_content) %}\n {{- '<|im_start|>' + message.role + '\\n<think>\\n' + reasoning_content.strip('\\n') + '\\n</think>\\n\\n' + content.lstrip('\\n') }}\n {%- else %}\n {{- '<|im_start|>' + message.role + '\\n' + content }}\n {%- endif %}\n {%- else %}\n {{- '<|im_start|>' + message.role + '\\n' + content }}\n {%- endif %}\n {%- if message.tool_calls %}\n {%- for tool_call in message.tool_calls %}\n {%- if (loop.first and content) or (not loop.first) %}\n {{- '\\n' }}\n {%- endif %}\n {%- if tool_call.function %}\n {%- set tool_call = tool_call.function %}\n {%- endif %}\n {{- '<tool_call>\\n{\"name\": \"' }}\n {{- tool_call.name }}\n {{- '\", \"arguments\": ' }}\n {%- if tool_call.arguments is string %}\n {{- tool_call.arguments }}\n {%- else %}\n {{- tool_call.arguments | tojson }}\n {%- endif %}\n {{- '}\\n</tool_call>' }}\n {%- endfor %}\n {%- endif %}\n {{- '<|im_end|>\\n' }}\n {%- elif message.role == \"tool\" %}\n {%- if loop.first or (messages[loop.index0 - 1].role != \"tool\") %}\n {{- '<|im_start|>user' }}\n {%- endif %}\n {{- '\\n<tool_response>\\n' }}\n {{- content }}\n {{- '\\n</tool_response>' }}\n {%- if loop.last or (messages[loop.index0 + 1].role != \"tool\") %}\n {{- '<|im_end|>\\n' }}\n {%- endif %}\n {%- endif %}\n{%- endfor %}\n{%- if add_generation_prompt %}\n {{- '<|im_start|>assistant\\n' }}\n {%- if enable_thinking is defined and enable_thinking is false %}\n {{- '<think>\\n\\n</think>\\n\\n' }}\n {%- endif %}\n{%- endif %}"
    }

    # Save the config file.
    with open(os.path.join(tokenizer_dir, "tokenizer_config.json"), "w", encoding="utf-8") as config_file:
        json.dump(config, config_file, ensure_ascii=False, indent=4)

    print("Tokenizer training completed and saved.")
|
||
|
||
|
||
def eval_tokenizer(tokenizer_dir='../model/'):
    """Smoke-test a saved tokenizer by rendering a chat, encoding it and decoding it back.

    Prints four things:
      * the rendered chat prompt;
      * ``len(tokenizer)``, the vocabulary size including special tokens;
      * the number of token ids produced for the prompt;
      * whether decoding those ids reproduces the prompt exactly.

    Args:
        tokenizer_dir: Directory holding the saved tokenizer files. The default
            matches the directory ``train_tokenizer`` writes to by default.
    """
    # Imported locally; presumably so training can run without `transformers`
    # installed — confirm.
    from transformers import AutoTokenizer

    # Load the trained tokenizer.
    tokenizer = AutoTokenizer.from_pretrained(tokenizer_dir)

    messages = [
        {"role": "system", "content": "你是一个优秀的聊天机器人,总是给我正确的回应!"},
        {"role": "user", "content": '你来自哪里?'},
        {"role": "assistant", "content": '我来自地球'},
    ]
    # Render to text only (tokenize=False) so the round-trip below compares strings.
    new_prompt = tokenizer.apply_chat_template(
        messages,
        tokenize=False,
    )
    print(new_prompt)

    # Actual vocabulary length, including special tokens.
    actual_vocab_size = len(tokenizer)
    print('tokenizer实际词表长度:', actual_vocab_size)

    input_ids = tokenizer(new_prompt)['input_ids']
    print('encoder长度:', len(input_ids))

    # Keep special tokens so the decoded text is comparable with the rendered prompt.
    response = tokenizer.decode(input_ids, skip_special_tokens=False)
    print('decoder和原始文本是否一致:', response == new_prompt)
|
||
|
||
|
||
def main():
    """Run the pipeline: train and save the tokenizer, then evaluate the saved copy."""
    # Order matters: evaluation loads the files that training writes.
    pipeline = (train_tokenizer, eval_tokenizer)
    for step in pipeline:
        step()
|
||
|
||
|
||
# Run the train-then-evaluate pipeline only when executed as a script, not on import.
if __name__ == '__main__':
    main()
|