add minimind2

This commit is contained in:
gongjy
2025-02-09 23:49:47 +08:00
parent 6e9cd28ef9
commit 58e3af0359
260 changed files with 4773 additions and 40206 deletions
+30
View File
@@ -0,0 +1,30 @@
from openai import OpenAI
client = OpenAI(
api_key="none",
base_url="http://localhost:8998/v1"
)
stream = True
conversation_history_origin = []
conversation_history = conversation_history_origin.copy()
while True:
conversation_history = conversation_history_origin.copy()
query = input('[Q]: ')
conversation_history.append({"role": "user", "content": query})
response = client.chat.completions.create(
model="minimind",
messages=conversation_history,
stream=stream
)
if not stream:
assistant_res = response.choices[0].message.content
print('[A]: ', assistant_res)
else:
print('[A]: ', end='')
assistant_res = ''
for chunk in response:
print(chunk.choices[0].delta.content or "", end="")
assistant_res += chunk.choices[0].delta.content or ""
conversation_history.append({"role": "assistant", "content": assistant_res})
print('\n\n')
+62
View File
@@ -0,0 +1,62 @@
import torch
import warnings
import sys
import os
__package__ = "scripts"
sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), '..')))
from transformers import AutoTokenizer, AutoModelForCausalLM
from model.LMConfig import LMConfig
from model.model import MiniMindLM
warnings.filterwarnings('ignore', category=UserWarning)
def convert_torch2transformers(torch_path, transformers_path):
def export_tokenizer(transformers_path):
tokenizer = AutoTokenizer.from_pretrained('../model/minimind_tokenizer')
tokenizer.save_pretrained(transformers_path)
LMConfig.register_for_auto_class()
MiniMindLM.register_for_auto_class("AutoModelForCausalLM")
lm_model = MiniMindLM(lm_config)
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
state_dict = torch.load(torch_path, map_location=device)
lm_model.load_state_dict(state_dict, strict=False)
model_params = sum(p.numel() for p in lm_model.parameters() if p.requires_grad)
print(f'模型参数: {model_params / 1e6} 百万 = {model_params / 1e9} B (Billion)')
lm_model.save_pretrained(transformers_path, safe_serialization=False)
export_tokenizer(transformers_path)
print(f"模型已保存为 Transformers 格式: {transformers_path}")
def convert_transformers2torch(transformers_path, torch_path):
model = AutoModelForCausalLM.from_pretrained(transformers_path, trust_remote_code=True)
torch.save(model.state_dict(), torch_path)
print(f"模型已保存为 PyTorch 格式: {torch_path}")
# don't need to use
def push_to_hf(export_model_path):
def init_model():
tokenizer = AutoTokenizer.from_pretrained('../model/minimind_tokenizer')
model = AutoModelForCausalLM.from_pretrained(export_model_path, trust_remote_code=True)
return model, tokenizer
model, tokenizer = init_model()
# model.push_to_hub(model_path)
# tokenizer.push_to_hub(model_path, safe_serialization=False)
if __name__ == '__main__':
lm_config = LMConfig(dim=512, n_layers=8, max_seq_len=8192, use_moe=False)
torch_path = f"../out/reason_{lm_config.dim}{'_moe' if lm_config.use_moe else ''}.pth"
transformers_path = '../MiniMind2-Small-R1'
# convert torch to transformers model
convert_torch2transformers(torch_path, transformers_path)
# # convert transformers to torch model
# convert_transformers2torch(transformers_path, torch_path)
+185
View File
@@ -0,0 +1,185 @@
import csv
import glob
import os
import re
import json
import jsonlines
import pandas as pd
from tqdm import tqdm
bos_token = "<s>"
eos_token = "</s>"
def pretrain_process():
# 定义输入和输出路径
input_dir = '../CCI3-HQ/data'
output_file = '../dataset/pretrain_data_hq.csv'
jsonl_files = glob.glob(os.path.join(input_dir, 'part_*.jsonl'))
total_lines = 0
print("正在计算总行数...")
for file in jsonl_files:
with open(file, 'r', encoding='utf-8') as f:
for _ in f:
total_lines += 1
with open(output_file, 'w', newline='', encoding='utf-8') as csvfile:
writer = csv.writer(csvfile)
writer.writerow(['text', 'score']) # 写入表头
for jsonl_file in jsonl_files:
with open(jsonl_file, 'r', encoding='utf-8') as f:
for line in tqdm(f, desc=f'处理 {os.path.basename(jsonl_file)}', total=total_lines, unit='',
leave=False):
try:
data = json.loads(line)
text = data.get('text', '')
score = data.get('score', 0)
if len(text) <= 512 and score > 3.5:
writer.writerow([text, score])
except json.JSONDecodeError:
continue
print(f"筛选完成,结果已保存到 {output_file}")
def sft_process():
sft_file_name = 'sft_data.csv'
def process_and_write_data(data):
q_lst, a_lst, history_lst = [], [], []
for per in data:
history, q, a = per['history'], per['q'], per['a']
if not q or not a:
continue
history_len = sum(len(s) for s in history)
message_len = history_len + len(q) + len(a)
if message_len < 70 or message_len > 512:
continue
q_lst.append(q)
a_lst.append(a)
history_lst.append(history)
df = pd.DataFrame({'history': history_lst, 'q': q_lst, 'a': a_lst})
df.to_csv(f'../dataset/{sft_file_name}',
mode='a', header=False, index=False,
lineterminator='\r\n', escapechar='\\', encoding='utf-8')
chunk_size = 1000
data = []
with open(f'../dataset/{sft_file_name}', 'w', encoding='utf-8') as f:
f.write('history,q,a\n')
# sft_path = ['/root/shared-nvme/sft_data_zh.jsonl', '/root/shared-nvme/sft_data_en.jsonl']
sft_path = ['/root/shared-nvme/sft_data_en.jsonl']
chunk_num = 0
for path in sft_path:
with jsonlines.open(path) as reader:
for idx, obj in enumerate(reader):
try:
data.append({
'history': obj.get('history', ''),
'q': obj.get('input', '') + obj.get('q', ''),
'a': obj.get('output', '') + obj.get('a', '')
})
if len(data) >= chunk_size:
chunk_num += 1
process_and_write_data(data)
data = []
if chunk_num % 100 == 0:
print(f'chunk:{chunk_num} process end')
except jsonlines.InvalidLineError as e:
print(f"Skipping invalid JSON line {idx + 1}: {e}")
continue
if data:
process_and_write_data(data)
data = []
def rl_process():
# 偏好数据默认只用中文(建议)
input_paths = [
# "../dataset/dpo_en.json",
"../dataset/dpo_zh.json"
]
output_path = "../dataset/dpo_data.jsonl" # 修改输出文件扩展名为 .jsonl
all_converted = []
for input_path in input_paths:
with open(input_path, "r", encoding="utf-8") as f:
data = json.load(f) # data is likely a list
for item in data:
new_data = {
"chosen": [],
"rejected": []
}
for turn in item["conversations"]:
role = "user" if turn["from"] == "human" else "assistant"
message = {"role": role, "content": turn["value"]}
new_data["chosen"].append(message)
new_data["rejected"].append(message)
new_data["chosen"].append({
"role": "assistant",
"content": item["chosen"]["value"]
})
new_data["rejected"].append({
"role": "assistant",
"content": item["rejected"]["value"]
})
all_converted.append(new_data)
with open(output_path, "w", encoding="utf-8") as f:
for item in all_converted:
f.write(json.dumps(item, ensure_ascii=False) + "\n")
def lora_dataset():
import json
import csv
# 读取JSON文件
with open('../dataset/Chinese-medical-dialogue.json', 'r', encoding='utf-8') as f:
data = json.load(f)
# 准备CSV数据
csv_data = []
for item in data:
# 提取input和output并去除首尾空白
q = item['input'].strip()
a = item['output'].strip()
# 检查长度是否符合要求
if len(q) + len(a) < 160:
csv_data.append({
'history': '[]',
'q': q,
'a': a
})
# 写入CSV文件
with open('../dataset/medical_sft.csv', 'w', newline='', encoding='utf-8') as csvfile:
fieldnames = ['history', 'q', 'a']
writer = csv.DictWriter(csvfile, fieldnames=fieldnames)
writer.writeheader()
writer.writerows(csv_data)
print(f'转换完成,共处理 {len(csv_data)} 条有效数据')
if __name__ == "__main__":
################
# 1: pretrain
# 2: sft
# 3: RL
################
process_type = 4
if process_type == 1:
pretrain_process()
if process_type == 2:
sft_process()
if process_type == 3:
rl_process()
if process_type == 4:
lora_dataset()
+97
View File
@@ -0,0 +1,97 @@
# from datasets import load_dataset
#
# dataset_paths = [
# ['ceval/ceval-exam',
# ['computer_network', 'operating_system', 'computer_architecture', 'college_programming', 'college_physics',
# 'college_chemistry', 'advanced_mathematics', 'probability_and_statistics', 'discrete_mathematics',
# 'electrical_engineer', 'metrology_engineer', 'high_school_mathematics', 'high_school_physics',
# 'high_school_chemistry', 'high_school_biology', 'middle_school_mathematics', 'middle_school_biology',
# 'middle_school_physics', 'middle_school_chemistry', 'veterinary_medicine', 'college_economics',
# 'business_administration', 'marxism', 'mao_zedong_thought', 'education_science', 'teacher_qualification',
# 'high_school_politics', 'high_school_geography', 'middle_school_politics', 'middle_school_geography',
# 'modern_chinese_history', 'ideological_and_moral_cultivation', 'logic', 'law', 'chinese_language_and_literature',
# 'art_studies', 'professional_tour_guide', 'legal_professional', 'high_school_chinese', 'high_school_history',
# 'middle_school_history', 'civil_servant', 'sports_science', 'plant_protection', 'basic_medicine',
# 'clinical_medicine', 'urban_and_rural_planner', 'accountant', 'fire_engineer',
# 'environmental_impact_assessment_engineer', 'tax_accountant', 'physician']], # ceval*
# ['haonan-li/cmmlu', [
# 'agronomy', 'anatomy', 'ancient_chinese', 'arts', 'astronomy', 'business_ethics',
# 'chinese_civil_service_exam', 'chinese_driving_rule', 'chinese_food_culture',
# 'chinese_foreign_policy', 'chinese_history', 'chinese_literature',
# 'chinese_teacher_qualification', 'clinical_knowledge', 'college_actuarial_science',
# 'college_education', 'college_engineering_hydrology', 'college_law',
# 'college_mathematics', 'college_medical_statistics', 'college_medicine',
# 'computer_science', 'computer_security', 'conceptual_physics',
# 'construction_project_management', 'economics', 'education', 'electrical_engineering',
# 'elementary_chinese', 'elementary_commonsense', 'elementary_information_and_technology',
# 'elementary_mathematics', 'ethnology', 'food_science', 'genetics', 'global_facts',
# 'high_school_biology', 'high_school_chemistry', 'high_school_geography',
# 'high_school_mathematics', 'high_school_physics', 'high_school_politics',
# 'human_sexuality', 'international_law', 'journalism', 'jurisprudence',
# 'legal_and_moral_basis', 'logical', 'machine_learning', 'management', 'marketing',
# 'marxist_theory', 'modern_chinese', 'nutrition', 'philosophy', 'professional_accounting',
# 'professional_law', 'professional_medicine', 'professional_psychology',
# 'public_relations', 'security_study', 'sociology', 'sports_science',
# 'traditional_chinese_medicine', 'virology', 'world_history', 'world_religions'
# ]], # cmmlu*
# ['tyouisen/aclue',
# ['polysemy_resolution', 'poetry_sentiment_analysis', 'named_entity_recognition', 'basic_ancient_chinese',
# 'poetry_context_prediction', 'sentence_segmentation', 'couplet_prediction', 'poetry_appreciate',
# 'ancient_chinese_culture', 'ancient_phonetics', 'homographic_character_resolution', 'ancient_literature',
# 'ancient_medical', 'poetry_quality_assessment', 'reading_comprehension']], # aclue
# ['juletxara/mgsm', ['zh']], # mgsm_direct_zh
# ['openbookqa', ['main']], # openbookqa
# ['ZoneTwelve/tmmluplus',
# ['dentistry', 'traditional_chinese_medicine_clinical_medicine', 'clinical_psychology', 'technical',
# 'culinary_skills', 'mechanical', 'logic_reasoning', 'real_estate', 'general_principles_of_law', 'finance_banking',
# 'anti_money_laundering', 'ttqav2', 'marketing_management', 'business_management', 'organic_chemistry',
# 'advance_chemistry', 'physics', 'secondary_physics', 'human_behavior', 'national_protection', 'jce_humanities',
# 'politic_science', 'agriculture', 'official_document_management', 'financial_analysis', 'pharmacy',
# 'educational_psychology', 'statistics_and_machine_learning', 'management_accounting', 'introduction_to_law',
# 'computer_science', 'veterinary_pathology', 'accounting', 'fire_science', 'optometry', 'insurance_studies',
# 'pharmacology', 'taxation', 'education_(profession_level)', 'economics', 'veterinary_pharmacology',
# 'nautical_science', 'occupational_therapy_for_psychological_disorders', 'trust_practice', 'geography_of_taiwan',
# 'physical_education', 'auditing', 'administrative_law', 'basic_medical_science', 'macroeconomics', 'trade',
# 'chinese_language_and_literature', 'tve_design', 'junior_science_exam', 'junior_math_exam', 'junior_chinese_exam',
# 'junior_social_studies', 'tve_mathematics', 'tve_chinese_language', 'tve_natural_sciences', 'junior_chemistry',
# 'music', 'education', 'three_principles_of_people', 'taiwanese_hokkien', 'engineering_math', 'linear_algebra']]
# # tmmluplus
#
# ]
#
# for dataset_path in dataset_paths:
# for dataset_name in dataset_path[1]:
# datasets = load_dataset(dataset_path[0], dataset_name, cache_dir='./test_dataset_cache')
#
# """
# export HF_HUB_OFFLINE=1 && lm_eval --model hf --model_args pretrained=/xxx/minimind/minimind-v2-small/,device=cuda,dtype=auto --tasks ceval* --batch_size 8 --trust_remote_code
# """
"""
$env:HF_HUB_OFFLINE=1; lm_eval --model hf --model_args pretrained=../minimind-v2-small/,device=cuda,dtype=auto --tasks ceval* --batch_size 8 --trust_remote_code
"""
import subprocess
# 定义要执行的命令
command = (
'set HF_HUB_OFFLINE=1 & '
'lm_eval --model hf --model_args pretrained=../minimind-v2-small/,device=cuda,dtype=auto '
'--tasks ceval* --batch_size 8 --trust_remote_code'
)
# 使用 subprocess 执行命令
try:
process = subprocess.run(
command,
shell=True,
check=True,
text=True,
stdout=subprocess.PIPE,
stderr=subprocess.PIPE,
)
# 打印命令的输出
print("STDOUT:", process.stdout)
print("STDERR:", process.stderr)
except subprocess.CalledProcessError as e:
print(f"命令执行失败,返回码: {e.returncode}")
print("STDERR:", e.stderr)
+164
View File
@@ -0,0 +1,164 @@
import argparse
import json
import os
import sys
__package__ = "scripts"
sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), '..')))
import time
import torch
import warnings
import uvicorn
from fastapi import FastAPI, HTTPException
from fastapi.responses import StreamingResponse
from pydantic import BaseModel
from transformers import AutoTokenizer, AutoModelForCausalLM
from model.LMConfig import LMConfig
from model.model import MiniMindLM
from model.model_lora import apply_lora, load_lora
warnings.filterwarnings('ignore')
app = FastAPI()
def init_model(args):
tokenizer = AutoTokenizer.from_pretrained('../model/minimind_tokenizer')
if args.load == 0:
moe_path = '_moe' if args.use_moe else ''
modes = {0: 'pretrain', 1: 'full_sft', 2: 'full_dist', 3: 'rlhf'}
ckp = f'../{args.out_dir}/{modes[args.model_mode]}_{args.dim}{moe_path}.pth'
model = MiniMindLM(LMConfig(
dim=args.dim,
n_layers=args.n_layers,
max_seq_len=args.max_seq_len,
use_moe=args.use_moe
))
state_dict = torch.load(ckp, map_location=device)
model.load_state_dict({k: v for k, v in state_dict.items() if 'mask' not in k}, strict=True)
if args.lora_name != 'None':
apply_lora(model)
load_lora(model, f'../{args.out_dir}/{args.lora_name}_{args.dim}.pth')
else:
model = AutoModelForCausalLM.from_pretrained(
'./MiniMind2',
trust_remote_code=True
)
print(f'MiniMind模型参数量: {sum(p.numel() for p in model.parameters() if p.requires_grad) / 1e6:.2f}M(illion)')
return model.eval().to(device), tokenizer
class ChatRequest(BaseModel):
model: str
messages: list
temperature: float = 0.7
top_p: int = 0.92
max_tokens: int = 8192
stream: bool = False
def generate_stream_response(messages, temperature, top_p, max_tokens):
try:
new_prompt = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)[-max_tokens:]
x = tokenizer(new_prompt).data['input_ids']
x = (torch.tensor(x, dtype=torch.long, device=device)[None, ...])
with torch.no_grad():
res_y = model.generate(
x,
eos_token_id=tokenizer.eos_token_id,
max_new_tokens=max_tokens,
temperature=temperature,
top_p=top_p,
stream=True,
rp=1.,
pad_token_id=tokenizer.pad_token_id
)
history_idx = 0
for y in res_y:
answer = tokenizer.decode(y[0].tolist(), skip_special_tokens=True)
if (answer and answer[-1] == '') or not answer:
continue
delta = answer[history_idx:]
history_idx = len(answer)
json_data = {
'id': f'chatcmpl-{int(time.time())}',
'object': 'chat.completion.chunk',
'created': int(time.time()),
'model': 'minimind',
'choices': [{'index': 0, 'delta': {'content': delta}, 'finish_reason': None}]
}
yield f"data: {json.dumps(json_data)}\n\n"
except Exception as e:
yield f"data: {json.dumps({'error': str(e)})}\n\n"
@app.post("/v1/chat/completions")
async def chat_completions(request: ChatRequest):
try:
if request.stream:
return StreamingResponse(
generate_stream_response(
messages=request.messages,
temperature=request.temperature,
top_p=request.top_p,
max_tokens=request.max_tokens
),
media_type="text/event-stream"
)
else:
new_prompt = tokenizer.apply_chat_template(
request.messages,
tokenize=False,
add_generation_prompt=True
)[-request.max_tokens:]
x = tokenizer(new_prompt).data['input_ids']
x = (torch.tensor(x, dtype=torch.long, device=device)[None, ...])
with torch.no_grad():
res_y = model.generate(
x,
eos_token_id=tokenizer.eos_token_id,
max_new_tokens=request.max_tokens,
temperature=request.temperature,
top_p=request.top_p,
stream=False,
rp=1.,
pad_token_id=tokenizer.pad_token_id
)
answer = tokenizer.decode(res_y.squeeze()[x.shape[1]:].tolist(), skip_special_tokens=True)
return {
"id": f"chatcmpl-{int(time.time())}",
"object": "chat.completion",
"created": int(time.time()),
"model": "minimind",
"choices": [
{
"index": 0,
"message": {"role": "assistant", "content": answer},
"finish_reason": "stop"
}
]
}
except Exception as e:
raise HTTPException(status_code=500, detail=str(e))
if __name__ == "__main__":
parser = argparse.ArgumentParser(description="Server for MiniMind")
parser.add_argument('--out_dir', default='out', type=str)
parser.add_argument('--lora_name', default='None', type=str)
parser.add_argument('--dim', default=512, type=int)
parser.add_argument('--n_layers', default=8, type=int)
parser.add_argument('--max_seq_len', default=8192, type=int)
parser.add_argument('--use_moe', default=False, type=bool)
parser.add_argument('--load', default=0, type=int, help="0: 从原生torch权重,1: 利用transformers加载")
parser.add_argument('--model_mode', default=1, type=int, help="0: 预训练模型,1: SFT-Chat模型,2: RLHF-Chat模型")
device = 'cuda' if torch.cuda.is_available() else 'cpu'
model, tokenizer = init_model(parser.parse_args())
uvicorn.run(app, host="0.0.0.0", port=8998)
+152
View File
@@ -0,0 +1,152 @@
import random
from tqdm import tqdm
from transformers import AutoTokenizer
import json
from datasets import load_dataset
from tokenizers import (
decoders,
models,
normalizers,
pre_tokenizers,
processors,
trainers,
Tokenizer,
)
import os
random.seed(42)
def train_tokenizer():
# 读取JSONL文件并提取文本数据
def read_texts_from_jsonl(file_path):
with open(file_path, 'r', encoding='utf-8') as f:
for line in f:
data = json.loads(line)
yield data['text']
data_path = '../dataset/tokenizer_train.jsonl'
# 初始化tokenizer
tokenizer = Tokenizer(models.BPE())
tokenizer.pre_tokenizer = pre_tokenizers.ByteLevel(add_prefix_space=False)
# 定义特殊token
special_tokens = ["<unk>", "<s>", "</s>"]
# 设置训练器并添加特殊token
trainer = trainers.BpeTrainer(
vocab_size=6400,
special_tokens=special_tokens, # 确保这三个token被包含
show_progress=True,
initial_alphabet=pre_tokenizers.ByteLevel.alphabet()
)
# 读取文本数据
texts = read_texts_from_jsonl(data_path)
# 训练tokenizer
tokenizer.train_from_iterator(texts, trainer=trainer)
# 设置解码器
tokenizer.decoder = decoders.ByteLevel()
# 检查特殊token的索引
assert tokenizer.token_to_id("<unk>") == 0
assert tokenizer.token_to_id("<s>") == 1
assert tokenizer.token_to_id("</s>") == 2
# 保存tokenizer
tokenizer_dir = "../model/minimind_tokenizer"
os.makedirs(tokenizer_dir, exist_ok=True)
tokenizer.save(os.path.join(tokenizer_dir, "tokenizer.json"))
tokenizer.model.save("../model/minimind_tokenizer")
# 手动创建配置文件
config = {
"add_bos_token": False,
"add_eos_token": False,
"add_prefix_space": False,
"added_tokens_decoder": {
"0": {
"content": "<unk>",
"lstrip": False,
"normalized": False,
"rstrip": False,
"single_word": False,
"special": True
},
"1": {
"content": "<s>",
"lstrip": False,
"normalized": False,
"rstrip": False,
"single_word": False,
"special": True
},
"2": {
"content": "</s>",
"lstrip": False,
"normalized": False,
"rstrip": False,
"single_word": False,
"special": True
}
},
"additional_special_tokens": [],
"bos_token": "<s>",
"clean_up_tokenization_spaces": False,
"eos_token": "</s>",
"legacy": True,
"model_max_length": 32768,
"pad_token": "<unk>",
"sp_model_kwargs": {},
"spaces_between_special_tokens": False,
"tokenizer_class": "PreTrainedTokenizerFast",
"unk_token": "<unk>",
"chat_template": "{% if messages[0]['role'] == 'system' %}{% set system_message = messages[0]['content'] %}{{ '<s>system\\n' + system_message + '</s>\\n' }}{% else %}{{ '<s>system\\n你是 MiniMind,是一个有用的人工智能助手。</s>\\n' }}{% endif %}{% for message in messages %}{% set content = message['content'] %}{% if message['role'] == 'user' %}{{ '<s>user\\n' + content + '</s>\\n<s>assistant\\n' }}{% elif message['role'] == 'assistant' %}{{ content + '</s>' + '\\n' }}{% endif %}{% endfor %}"
}
# 保存配置文件
with open(os.path.join(tokenizer_dir, "tokenizer_config.json"), "w", encoding="utf-8") as config_file:
json.dump(config, config_file, ensure_ascii=False, indent=4)
print("Tokenizer training completed and saved.")
def eval_tokenizer():
from transformers import AutoTokenizer
# 加载预训练的tokenizer
tokenizer = AutoTokenizer.from_pretrained("../model/minimind_tokenizer")
messages = [
{"role": "system", "content": "你是一个优秀的聊天机器人,总是给我正确的回应!"},
{"role": "user", "content": '你来自哪里?'},
{"role": "assistant", "content": '我来自地球'}
]
new_prompt = tokenizer.apply_chat_template(
messages,
tokenize=False
)
print(new_prompt)
# 获取实际词汇表长度(包括特殊符号)
actual_vocab_size = len(tokenizer)
print('tokenizer实际词表长度:', actual_vocab_size)
model_inputs = tokenizer(new_prompt)
print('encoder长度:', len(model_inputs['input_ids']))
input_ids = model_inputs['input_ids']
response = tokenizer.decode(input_ids, skip_special_tokens=True)
print('decoder和原始文本是否一致:', response == new_prompt)
def main():
# train_tokenizer()
eval_tokenizer()
if __name__ == '__main__':
main()
+293
View File
@@ -0,0 +1,293 @@
import random
import re
import time
import numpy as np
import streamlit as st
import torch
st.set_page_config(page_title="MiniMind", initial_sidebar_state="collapsed")
# 在文件开头的 CSS 样式中修改按钮样式
st.markdown("""
<style>
/* 添加操作按钮样式 */
.stButton button {
border-radius: 50% !important; /* 改为圆形 */
width: 32px !important; /* 固定宽度 */
height: 32px !important; /* 固定高度 */
padding: 0 !important; /* 移除内边距 */
background-color: transparent !important;
border: 1px solid #ddd !important;
display: flex !important;
align-items: center !important;
justify-content: center !important;
font-size: 14px !important;
color: #666 !important; /* 更柔和的颜色 */
margin: 5px 10px 5px 0 !important; /* 调整按钮间距 */
}
.stButton button:hover {
border-color: #999 !important;
color: #333 !important;
background-color: #f5f5f5 !important;
}
.stMainBlockContainer > div:first-child {
margin-top: -50px !important;
}
.stApp > div:last-child {
margin-bottom: -35px !important;
}
/* 重置按钮基础样式 */
.stButton > button {
all: unset !important; /* 重置所有默认样式 */
box-sizing: border-box !important;
border-radius: 50% !important;
width: 18px !important;
height: 18px !important;
min-width: 18px !important;
min-height: 18px !important;
max-width: 18px !important;
max-height: 18px !important;
padding: 0 !important;
background-color: transparent !important;
border: 1px solid #ddd !important;
display: flex !important;
align-items: center !important;
justify-content: center !important;
font-size: 14px !important;
color: #888 !important;
cursor: pointer !important;
transition: all 0.2s ease !important;
margin: 0 2px !important; /* 调整这里的 margin 值 */
}
</style>
""", unsafe_allow_html=True)
system_prompt = []
device = "cuda" if torch.cuda.is_available() else "cpu"
def process_assistant_content(content):
if 'R1' not in MODEL_PATHS[selected_model][1]:
return content
if '<think>' in content and '</think>' in content:
content = re.sub(r'(<think>)(.*?)(</think>)',
r'<details style="font-style: italic; background: rgba(222, 222, 222, 0.5); padding: 10px; border-radius: 10px;"><summary style="font-weight:bold;">推理内容(展开)</summary>\2</details>',
content,
flags=re.DOTALL)
if '<think>' in content and '</think>' not in content:
content = re.sub(r'<think>(.*?)$',
r'<details open style="font-style: italic; background: rgba(222, 222, 222, 0.5); padding: 10px; border-radius: 10px;"><summary style="font-weight:bold;">推理中...</summary>\1</details>',
content,
flags=re.DOTALL)
if '<think>' not in content and '</think>' in content:
content = re.sub(r'(.*?)</think>',
r'<details style="font-style: italic; background: rgba(222, 222, 222, 0.5); padding: 10px; border-radius: 10px;"><summary style="font-weight:bold;">推理内容(展开)</summary>\1</details>',
content,
flags=re.DOTALL)
return content
@st.cache_resource
def load_model_tokenizer(model_path):
model = AutoModelForCausalLM.from_pretrained(
model_path,
trust_remote_code=True
)
tokenizer = AutoTokenizer.from_pretrained(
model_path,
use_fast=False,
trust_remote_code=True
)
model = model.eval().to(device)
return model, tokenizer
def clear_chat_messages():
del st.session_state.messages
del st.session_state.chat_messages
def init_chat_messages():
if "messages" in st.session_state:
for i, message in enumerate(st.session_state.messages):
if message["role"] == "assistant":
with st.chat_message("assistant", avatar=image_url):
st.markdown(process_assistant_content(message["content"]), unsafe_allow_html=True)
# 在消息内容下方添加按钮
if st.button("🗑", key=f"delete_{i}"):
st.session_state.messages.pop(i)
st.session_state.messages.pop(i - 1)
st.session_state.chat_messages.pop(i)
st.session_state.chat_messages.pop(i - 1)
st.rerun()
else:
st.markdown(
f'<div style="display: flex; justify-content: flex-end;"><div style="display: inline-block; margin: 10px 0; padding: 8px 12px 8px 12px; background-color: #ddd; border-radius: 10px; color: black;">{message["content"]}</div></div>',
unsafe_allow_html=True)
else:
st.session_state.messages = []
st.session_state.chat_messages = []
return st.session_state.messages
# 添加这两个辅助函数
def regenerate_answer(index):
st.session_state.messages.pop()
st.session_state.chat_messages.pop()
st.rerun()
def delete_conversation(index):
st.session_state.messages.pop(index)
st.session_state.messages.pop(index - 1)
st.session_state.chat_messages.pop(index)
st.session_state.chat_messages.pop(index - 1)
st.rerun()
# 侧边栏模型选择
st.sidebar.title("模型设定调整")
st.sidebar.text("【注】训练数据偏差,增加上下文记忆时\n多轮对话(较单轮)容易出现能力衰减")
st.session_state.history_chat_num = st.sidebar.slider("Number of Historical Dialogues", 0, 6, 0, step=2)
# st.session_state.history_chat_num = 0
st.session_state.max_new_tokens = st.sidebar.slider("Max Sequence Length", 256, 8192, 8192, step=1)
st.session_state.top_p = st.sidebar.slider("Top-P", 0.8, 0.99, 0.85, step=0.01)
st.session_state.temperature = st.sidebar.slider("Temperature", 0.6, 1.2, 0.85, step=0.01)
# 模型路径映射
MODEL_PATHS = {
"MiniMind2-Pro-R1 (0.1B)": ["../MiniMind2-Pro-R1", "MiniMind2-Pro-R1"],
"MiniMind2-R1 (0.05B)": ["../MiniMind2-R1", "MiniMind2-R1"],
"MiniMind2-Pro (0.1B)": ["../MiniMind2-Pro", "MiniMind2-Pro"],
"MiniMind2 (0.05B)": ["../MiniMind2", "MiniMind2"],
"MiniMind2-Small (0.02B)": ["../MiniMind2-Small", "MiniMind2-Small"],
"MiniMind-V1 (0.1B)": ["../minimind-v1", "MiniMind-V1"],
"MiniMind-V1-Small (0.02B)": ["../minimind-v1-small", "MiniMind-V1 Small"],
}
selected_model = st.sidebar.selectbox('Models', list(MODEL_PATHS.keys()), index=0) # 默认选择 MiniMind2
model_path = MODEL_PATHS[selected_model][0]
slogan = f"Hi, I'm {MODEL_PATHS[selected_model][1]}"
image_url = "https://www.modelscope.cn/api/v1/studio/gongjy/MiniMind/repo?Revision=master&FilePath=images%2Flogo2.png&View=true"
st.markdown(
f'<div style="display: flex; flex-direction: column; align-items: center; text-align: center; margin: 0; padding: 0;">'
'<div style="font-style: italic; font-weight: 900; margin: 0; padding-top: 4px; display: flex; align-items: center; justify-content: center; flex-wrap: wrap; width: 100%;">'
f'<img src="{image_url}" style="width: 45px; height: 45px; "> '
f'<span style="font-size: 26px; margin-left: 10px;">{slogan}</span>'
'</div>'
'<span style="color: #bbb; font-style: italic; margin-top: 6px; margin-bottom: 10px;">内容完全由AI生成,请务必仔细甄别<br>Content AI-generated, please discern with care</span>'
'</div>',
unsafe_allow_html=True
)
def setup_seed(seed):
random.seed(seed)
np.random.seed(seed)
torch.manual_seed(seed)
torch.cuda.manual_seed(seed)
torch.cuda.manual_seed_all(seed)
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False
def main():
model, tokenizer = load_model_tokenizer(model_path)
# 初始化消息列表
if "messages" not in st.session_state:
st.session_state.messages = []
st.session_state.chat_messages = []
# Use session state messages
messages = st.session_state.messages
# 在显示历史消息的循环中
for i, message in enumerate(messages):
if message["role"] == "assistant":
with st.chat_message("assistant", avatar=image_url):
st.markdown(process_assistant_content(message["content"]), unsafe_allow_html=True)
if st.button("×", key=f"delete_{i}"):
# 删除当前消息及其之后的所有消息
st.session_state.messages = st.session_state.messages[:i - 1]
st.session_state.chat_messages = st.session_state.chat_messages[:i - 1]
st.rerun()
else:
st.markdown(
f'<div style="display: flex; justify-content: flex-end;"><div style="display: inline-block; margin: 10px 0; padding: 8px 12px 8px 12px; background-color: gray; border-radius: 10px; color:white; ">{message["content"]}</div></div>',
unsafe_allow_html=True)
# 处理新的输入或重新生成
prompt = st.chat_input(key="input", placeholder="给 MiniMind 发送消息")
# 检查是否需要重新生成
if hasattr(st.session_state, 'regenerate') and st.session_state.regenerate:
prompt = st.session_state.last_user_message
regenerate_index = st.session_state.regenerate_index # 获取重新生成的位置
# 清除所有重新生成相关的状态
delattr(st.session_state, 'regenerate')
delattr(st.session_state, 'last_user_message')
delattr(st.session_state, 'regenerate_index')
if prompt:
st.markdown(
f'<div style="display: flex; justify-content: flex-end;"><div style="display: inline-block; margin: 10px 0; padding: 8px 12px 8px 12px; background-color: gray; border-radius: 10px; color:white; ">{prompt}</div></div>',
unsafe_allow_html=True)
messages.append({"role": "user", "content": prompt})
st.session_state.chat_messages.append({"role": "user", "content": prompt})
with st.chat_message("assistant", avatar=image_url):
placeholder = st.empty()
random_seed = random.randint(0, 2 ** 32 - 1)
setup_seed(random_seed)
st.session_state.chat_messages = system_prompt + st.session_state.chat_messages[
-(st.session_state.history_chat_num + 1):]
new_prompt = tokenizer.apply_chat_template(
st.session_state.chat_messages,
tokenize=False,
add_generation_prompt=True
)[-(st.session_state.max_new_tokens - 1):]
x = torch.tensor(tokenizer(new_prompt)['input_ids'], device=device).unsqueeze(0)
with torch.no_grad():
res_y = model.generate(x, tokenizer.eos_token_id, max_new_tokens=st.session_state.max_new_tokens,
temperature=st.session_state.temperature,
top_p=st.session_state.top_p, stream=True)
try:
for y in res_y:
answer = tokenizer.decode(y[0].tolist(), skip_special_tokens=True)
if (answer and answer[-1] == '') or not answer:
continue
placeholder.markdown(process_assistant_content(answer), unsafe_allow_html=True)
except StopIteration:
print("No answer")
assistant_answer = answer.replace(new_prompt, "")
messages.append({"role": "assistant", "content": assistant_answer})
st.session_state.chat_messages.append({"role": "assistant", "content": assistant_answer})
with st.empty():
if st.button("×", key=f"delete_{len(messages) - 1}"):
st.session_state.messages = st.session_state.messages[:-2]
st.session_state.chat_messages = st.session_state.chat_messages[:-2]
st.rerun()
if __name__ == "__main__":
from transformers import AutoModelForCausalLM, AutoTokenizer
main()