mirror of
https://github.com/jingyaogong/minimind.git
synced 2026-01-13 19:57:20 +08:00
[update] rename train tokenizer
This commit is contained in:
parent
9830915d87
commit
07364c3fbe
@ -1,63 +1,40 @@
|
|||||||
import random
|
# 注:不建议再重复训练tokenizer(“词典”),MiniMind已自带,此脚本仅供学习和参考。基于不同词典训练的模型将导致输出完全不统一,降低社区的模型复用性
|
||||||
import json
|
# Note: It is not recommended to re-train the tokenizer. MiniMind already includes one. This script is for learning and reference only. Training models with different tokenizers will lead to inconsistent outputs and reduce model reusability in the community.
|
||||||
from tokenizers import (
|
|
||||||
decoders,
|
|
||||||
models,
|
|
||||||
pre_tokenizers,
|
|
||||||
trainers,
|
|
||||||
Tokenizer,
|
|
||||||
)
|
|
||||||
import os
|
import os
|
||||||
|
import json
|
||||||
|
from tokenizers import decoders, models, pre_tokenizers, trainers, Tokenizer
|
||||||
|
|
||||||
random.seed(42)
|
DATA_PATH = '../dataset/pretrain_hq.jsonl'
|
||||||
|
TOKENIZER_DIR = '../model_learn_tokenizer/'
|
||||||
|
VOCAB_SIZE = 6400
|
||||||
|
|
||||||
|
def get_texts(data_path):
|
||||||
def train_tokenizer():
|
with open(data_path, 'r', encoding='utf-8') as f:
|
||||||
# 读取JSONL文件并提取文本数据
|
for i, line in enumerate(f):
|
||||||
def read_texts_from_jsonl(file_path):
|
if i >= 10000: break # 实验性,可只用前10000行测试
|
||||||
with open(file_path, 'r', encoding='utf-8') as f:
|
|
||||||
for line in f:
|
|
||||||
data = json.loads(line)
|
data = json.loads(line)
|
||||||
yield data['text']
|
yield data['text']
|
||||||
|
|
||||||
data_path = '../dataset/pretrain_hq.jsonl'
|
def train_tokenizer(data_path, tokenizer_dir, vocab_size):
|
||||||
|
|
||||||
# 初始化tokenizer
|
|
||||||
tokenizer = Tokenizer(models.BPE())
|
tokenizer = Tokenizer(models.BPE())
|
||||||
tokenizer.pre_tokenizer = pre_tokenizers.ByteLevel(add_prefix_space=False)
|
tokenizer.pre_tokenizer = pre_tokenizers.ByteLevel(add_prefix_space=False)
|
||||||
|
|
||||||
# 定义特殊token
|
|
||||||
special_tokens = ["<|endoftext|>", "<|im_start|>", "<|im_end|>"]
|
|
||||||
|
|
||||||
# 设置训练器并添加特殊token
|
|
||||||
trainer = trainers.BpeTrainer(
|
trainer = trainers.BpeTrainer(
|
||||||
vocab_size=6400,
|
vocab_size=vocab_size,
|
||||||
special_tokens=special_tokens, # 确保这三个token被包含
|
special_tokens=["<|endoftext|>", "<|im_start|>", "<|im_end|>"],
|
||||||
show_progress=True,
|
show_progress=True,
|
||||||
initial_alphabet=pre_tokenizers.ByteLevel.alphabet()
|
initial_alphabet=pre_tokenizers.ByteLevel.alphabet()
|
||||||
)
|
)
|
||||||
|
texts = get_texts(data_path)
|
||||||
# 读取文本数据
|
|
||||||
texts = read_texts_from_jsonl(data_path)
|
|
||||||
|
|
||||||
# 训练tokenizer
|
|
||||||
tokenizer.train_from_iterator(texts, trainer=trainer)
|
tokenizer.train_from_iterator(texts, trainer=trainer)
|
||||||
|
|
||||||
# 设置解码器
|
|
||||||
tokenizer.decoder = decoders.ByteLevel()
|
tokenizer.decoder = decoders.ByteLevel()
|
||||||
|
|
||||||
# 检查特殊token的索引
|
|
||||||
assert tokenizer.token_to_id("<|endoftext|>") == 0
|
assert tokenizer.token_to_id("<|endoftext|>") == 0
|
||||||
assert tokenizer.token_to_id("<|im_start|>") == 1
|
assert tokenizer.token_to_id("<|im_start|>") == 1
|
||||||
assert tokenizer.token_to_id("<|im_end|>") == 2
|
assert tokenizer.token_to_id("<|im_end|>") == 2
|
||||||
|
|
||||||
# 保存tokenizer
|
|
||||||
tokenizer_dir = "../model/"
|
|
||||||
os.makedirs(tokenizer_dir, exist_ok=True)
|
os.makedirs(tokenizer_dir, exist_ok=True)
|
||||||
tokenizer.save(os.path.join(tokenizer_dir, "tokenizer.json"))
|
tokenizer.save(os.path.join(tokenizer_dir, "tokenizer.json"))
|
||||||
tokenizer.model.save("../model/")
|
tokenizer.model.save(tokenizer_dir)
|
||||||
|
|
||||||
# 手动创建配置文件
|
|
||||||
config = {
|
config = {
|
||||||
"add_bos_token": False,
|
"add_bos_token": False,
|
||||||
"add_eos_token": False,
|
"add_eos_token": False,
|
||||||
@ -102,19 +79,14 @@ def train_tokenizer():
|
|||||||
"chat_template": "{%- if tools %}\n {{- '<|im_start|>system\\n' }}\n {%- if messages[0].role == 'system' %}\n {{- messages[0].content + '\\n\\n' }}\n {%- endif %}\n {{- \"# Tools\\n\\nYou may call one or more functions to assist with the user query.\\n\\nYou are provided with function signatures within <tools></tools> XML tags:\\n<tools>\" }}\n {%- for tool in tools %}\n {{- \"\\n\" }}\n {{- tool | tojson }}\n {%- endfor %}\n {{- \"\\n</tools>\\n\\nFor each function call, return a json object with function name and arguments within <tool_call></tool_call> XML tags:\\n<tool_call>\\n{\\\"name\\\": <function-name>, \\\"arguments\\\": <args-json-object>}\\n</tool_call><|im_end|>\\n\" }}\n{%- else %}\n {%- if messages[0]['role'] == 'system' -%}\n {{- '<|im_start|>system\\n' + messages[0]['content'] + '<|im_end|>\\n' }}\n {%- else -%}\n {{- '<|im_start|>system\\nYou are a helpful assistant<|im_end|>\\n' }}\n {%- endif %}\n{%- endif %}\n{%- set ns = namespace(multi_step_tool=true, last_query_index=messages|length - 1) %}\n{%- for message in messages[::-1] %}\n {%- set index = (messages|length - 1) - loop.index0 %}\n {%- if ns.multi_step_tool and message.role == \"user\" and message.content is string and not(message.content.startswith('<tool_response>') and message.content.endswith('</tool_response>')) %}\n {%- set ns.multi_step_tool = false %}\n {%- set ns.last_query_index = index %}\n {%- endif %}\n{%- endfor %}\n{%- for message in messages %}\n {%- if message.content is string %}\n {%- set content = message.content %}\n {%- else %}\n {%- set content = '' %}\n {%- endif %}\n {%- if (message.role == \"user\") or (message.role == \"system\" and not loop.first) %}\n {{- '<|im_start|>' + message.role + '\\n' + content + '<|im_end|>' + '\\n' }}\n {%- elif message.role == \"assistant\" %}\n {{- '<|im_start|>' + message.role + '\\n' + content }}\n {%- if message.tool_calls %}\n {%- for tool_call in message.tool_calls %}\n {%- if (loop.first and content) or (not loop.first) %}\n {{- '\\n' }}\n {%- endif %}\n {%- if tool_call.function %}\n {%- set tool_call = tool_call.function %}\n {%- endif %}\n {{- '<tool_call>\\n{\"name\": \"' }}\n {{- tool_call.name }}\n {{- '\", \"arguments\": ' }}\n {%- if tool_call.arguments is string %}\n {{- tool_call.arguments }}\n {%- else %}\n {{- tool_call.arguments | tojson }}\n {%- endif %}\n {{- '}\\n</tool_call>' }}\n {%- endfor %}\n {%- endif %}\n {{- '<|im_end|>\\n' }}\n {%- elif message.role == \"tool\" %}\n {%- if loop.first or (messages[loop.index0 - 1].role != \"tool\") %}\n {{- '<|im_start|>user' }}\n {%- endif %}\n {{- '\\n<tool_response>\\n' }}\n {{- content }}\n {{- '\\n</tool_response>' }}\n {%- if loop.last or (messages[loop.index0 + 1].role != \"tool\") %}\n {{- '<|im_end|>\\n' }}\n {%- endif %}\n {%- endif %}\n{%- endfor %}\n{%- if add_generation_prompt %}\n {{- '<|im_start|>assistant\\n' }}\n {%- if enable_thinking is defined and enable_thinking is false %}\n {{- '<think>\\n\\n</think>\\n\\n' }}\n {%- endif %}\n{%- endif %}"
|
"chat_template": "{%- if tools %}\n {{- '<|im_start|>system\\n' }}\n {%- if messages[0].role == 'system' %}\n {{- messages[0].content + '\\n\\n' }}\n {%- endif %}\n {{- \"# Tools\\n\\nYou may call one or more functions to assist with the user query.\\n\\nYou are provided with function signatures within <tools></tools> XML tags:\\n<tools>\" }}\n {%- for tool in tools %}\n {{- \"\\n\" }}\n {{- tool | tojson }}\n {%- endfor %}\n {{- \"\\n</tools>\\n\\nFor each function call, return a json object with function name and arguments within <tool_call></tool_call> XML tags:\\n<tool_call>\\n{\\\"name\\\": <function-name>, \\\"arguments\\\": <args-json-object>}\\n</tool_call><|im_end|>\\n\" }}\n{%- else %}\n {%- if messages[0]['role'] == 'system' -%}\n {{- '<|im_start|>system\\n' + messages[0]['content'] + '<|im_end|>\\n' }}\n {%- else -%}\n {{- '<|im_start|>system\\nYou are a helpful assistant<|im_end|>\\n' }}\n {%- endif %}\n{%- endif %}\n{%- set ns = namespace(multi_step_tool=true, last_query_index=messages|length - 1) %}\n{%- for message in messages[::-1] %}\n {%- set index = (messages|length - 1) - loop.index0 %}\n {%- if ns.multi_step_tool and message.role == \"user\" and message.content is string and not(message.content.startswith('<tool_response>') and message.content.endswith('</tool_response>')) %}\n {%- set ns.multi_step_tool = false %}\n {%- set ns.last_query_index = index %}\n {%- endif %}\n{%- endfor %}\n{%- for message in messages %}\n {%- if message.content is string %}\n {%- set content = message.content %}\n {%- else %}\n {%- set content = '' %}\n {%- endif %}\n {%- if (message.role == \"user\") or (message.role == \"system\" and not loop.first) %}\n {{- '<|im_start|>' + message.role + '\\n' + content + '<|im_end|>' + '\\n' }}\n {%- elif message.role == \"assistant\" %}\n {{- '<|im_start|>' + message.role + '\\n' + content }}\n {%- if message.tool_calls %}\n {%- for tool_call in message.tool_calls %}\n {%- if (loop.first and content) or (not loop.first) %}\n {{- '\\n' }}\n {%- endif %}\n {%- if tool_call.function %}\n {%- set tool_call = tool_call.function %}\n {%- endif %}\n {{- '<tool_call>\\n{\"name\": \"' }}\n {{- tool_call.name }}\n {{- '\", \"arguments\": ' }}\n {%- if tool_call.arguments is string %}\n {{- tool_call.arguments }}\n {%- else %}\n {{- tool_call.arguments | tojson }}\n {%- endif %}\n {{- '}\\n</tool_call>' }}\n {%- endfor %}\n {%- endif %}\n {{- '<|im_end|>\\n' }}\n {%- elif message.role == \"tool\" %}\n {%- if loop.first or (messages[loop.index0 - 1].role != \"tool\") %}\n {{- '<|im_start|>user' }}\n {%- endif %}\n {{- '\\n<tool_response>\\n' }}\n {{- content }}\n {{- '\\n</tool_response>' }}\n {%- if loop.last or (messages[loop.index0 + 1].role != \"tool\") %}\n {{- '<|im_end|>\\n' }}\n {%- endif %}\n {%- endif %}\n{%- endfor %}\n{%- if add_generation_prompt %}\n {{- '<|im_start|>assistant\\n' }}\n {%- if enable_thinking is defined and enable_thinking is false %}\n {{- '<think>\\n\\n</think>\\n\\n' }}\n {%- endif %}\n{%- endif %}"
|
||||||
}
|
}
|
||||||
|
|
||||||
# 保存配置文件
|
with open(os.path.join(tokenizer_dir, "tokenizer_config.json"), "w", encoding="utf-8") as f:
|
||||||
with open(os.path.join(tokenizer_dir, "tokenizer_config.json"), "w", encoding="utf-8") as config_file:
|
json.dump(config, f, ensure_ascii=False, indent=4)
|
||||||
json.dump(config, config_file, ensure_ascii=False, indent=4)
|
print("Tokenizer training completed.")
|
||||||
|
|
||||||
print("Tokenizer training completed and saved.")
|
|
||||||
|
|
||||||
|
|
||||||
def eval_tokenizer():
|
def eval_tokenizer(tokenizer_dir):
|
||||||
from transformers import AutoTokenizer
|
from transformers import AutoTokenizer
|
||||||
|
tokenizer = AutoTokenizer.from_pretrained(tokenizer_dir)
|
||||||
# 加载预训练的tokenizer
|
|
||||||
tokenizer = AutoTokenizer.from_pretrained("../model/")
|
|
||||||
|
|
||||||
messages = [
|
messages = [
|
||||||
{"role": "system", "content": "你是一个优秀的聊天机器人,总是给我正确的回应!"},
|
{"role": "system", "content": "你是一个优秀的聊天机器人,总是给我正确的回应!"},
|
||||||
{"role": "user", "content": '你来自哪里?'},
|
{"role": "user", "content": '你来自哪里?'},
|
||||||
@ -124,24 +96,31 @@ def eval_tokenizer():
|
|||||||
messages,
|
messages,
|
||||||
tokenize=False
|
tokenize=False
|
||||||
)
|
)
|
||||||
|
print('-'*100)
|
||||||
print(new_prompt)
|
print(new_prompt)
|
||||||
|
|
||||||
# 获取实际词汇表长度(包括特殊符号)
|
|
||||||
actual_vocab_size = len(tokenizer)
|
|
||||||
print('tokenizer实际词表长度:', actual_vocab_size)
|
|
||||||
|
|
||||||
|
print('-'*100)
|
||||||
|
print('tokenizer词表长度:', len(tokenizer))
|
||||||
model_inputs = tokenizer(new_prompt)
|
model_inputs = tokenizer(new_prompt)
|
||||||
print('encoder长度:', len(model_inputs['input_ids']))
|
print('encoder长度:', len(model_inputs['input_ids']))
|
||||||
|
response = tokenizer.decode(model_inputs['input_ids'], skip_special_tokens=False)
|
||||||
|
print('decoder一致性:', response == new_prompt, "\n")
|
||||||
|
|
||||||
|
|
||||||
|
print('-'*100)
|
||||||
|
print('流式解码(字节缓冲)测试:')
|
||||||
input_ids = model_inputs['input_ids']
|
input_ids = model_inputs['input_ids']
|
||||||
response = tokenizer.decode(input_ids, skip_special_tokens=False)
|
token_cache = []
|
||||||
print('decoder和原始文本是否一致:', response == new_prompt)
|
for tid in input_ids:
|
||||||
|
token_cache.append(tid)
|
||||||
|
current_decode = tokenizer.decode(token_cache)
|
||||||
def main():
|
if current_decode and '\ufffd' not in current_decode:
|
||||||
train_tokenizer()
|
display_ids = token_cache[0] if len(token_cache) == 1 else token_cache
|
||||||
eval_tokenizer()
|
raw_tokens = [tokenizer.convert_ids_to_tokens(int(t)) for t in (token_cache if isinstance(token_cache, list) else [token_cache])]
|
||||||
|
print(f'Token ID: {str(display_ids):15} -> Raw: {str(raw_tokens):20} -> Decode Str: {current_decode}')
|
||||||
|
token_cache = []
|
||||||
|
|
||||||
if __name__ == '__main__':
|
if __name__ == '__main__':
|
||||||
main()
|
train_tokenizer(DATA_PATH, TOKENIZER_DIR, VOCAB_SIZE)
|
||||||
|
eval_tokenizer(TOKENIZER_DIR)
|
||||||
Loading…
Reference in New Issue
Block a user