mirror of
https://github.com/jingyaogong/minimind.git
synced 2026-06-06 00:04:50 +00:00
add minimind2
This commit is contained in:
@@ -0,0 +1,30 @@
|
||||
from openai import OpenAI
|
||||
|
||||
client = OpenAI(
|
||||
api_key="none",
|
||||
base_url="http://localhost:8998/v1"
|
||||
)
|
||||
stream = True
|
||||
conversation_history_origin = []
|
||||
conversation_history = conversation_history_origin.copy()
|
||||
while True:
|
||||
conversation_history = conversation_history_origin.copy()
|
||||
query = input('[Q]: ')
|
||||
conversation_history.append({"role": "user", "content": query})
|
||||
response = client.chat.completions.create(
|
||||
model="minimind",
|
||||
messages=conversation_history,
|
||||
stream=stream
|
||||
)
|
||||
if not stream:
|
||||
assistant_res = response.choices[0].message.content
|
||||
print('[A]: ', assistant_res)
|
||||
else:
|
||||
print('[A]: ', end='')
|
||||
assistant_res = ''
|
||||
for chunk in response:
|
||||
print(chunk.choices[0].delta.content or "", end="")
|
||||
assistant_res += chunk.choices[0].delta.content or ""
|
||||
|
||||
conversation_history.append({"role": "assistant", "content": assistant_res})
|
||||
print('\n\n')
|
||||
@@ -0,0 +1,62 @@
|
||||
import torch
|
||||
import warnings
|
||||
import sys
|
||||
import os
|
||||
|
||||
__package__ = "scripts"
|
||||
sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), '..')))
|
||||
from transformers import AutoTokenizer, AutoModelForCausalLM
|
||||
from model.LMConfig import LMConfig
|
||||
from model.model import MiniMindLM
|
||||
|
||||
warnings.filterwarnings('ignore', category=UserWarning)
|
||||
|
||||
|
||||
def convert_torch2transformers(torch_path, transformers_path):
|
||||
def export_tokenizer(transformers_path):
|
||||
tokenizer = AutoTokenizer.from_pretrained('../model/minimind_tokenizer')
|
||||
tokenizer.save_pretrained(transformers_path)
|
||||
|
||||
LMConfig.register_for_auto_class()
|
||||
MiniMindLM.register_for_auto_class("AutoModelForCausalLM")
|
||||
lm_model = MiniMindLM(lm_config)
|
||||
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
|
||||
state_dict = torch.load(torch_path, map_location=device)
|
||||
lm_model.load_state_dict(state_dict, strict=False)
|
||||
model_params = sum(p.numel() for p in lm_model.parameters() if p.requires_grad)
|
||||
print(f'模型参数: {model_params / 1e6} 百万 = {model_params / 1e9} B (Billion)')
|
||||
lm_model.save_pretrained(transformers_path, safe_serialization=False)
|
||||
export_tokenizer(transformers_path)
|
||||
print(f"模型已保存为 Transformers 格式: {transformers_path}")
|
||||
|
||||
|
||||
def convert_transformers2torch(transformers_path, torch_path):
|
||||
model = AutoModelForCausalLM.from_pretrained(transformers_path, trust_remote_code=True)
|
||||
torch.save(model.state_dict(), torch_path)
|
||||
print(f"模型已保存为 PyTorch 格式: {torch_path}")
|
||||
|
||||
|
||||
# don't need to use
|
||||
def push_to_hf(export_model_path):
|
||||
def init_model():
|
||||
tokenizer = AutoTokenizer.from_pretrained('../model/minimind_tokenizer')
|
||||
model = AutoModelForCausalLM.from_pretrained(export_model_path, trust_remote_code=True)
|
||||
return model, tokenizer
|
||||
|
||||
model, tokenizer = init_model()
|
||||
# model.push_to_hub(model_path)
|
||||
# tokenizer.push_to_hub(model_path, safe_serialization=False)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
lm_config = LMConfig(dim=512, n_layers=8, max_seq_len=8192, use_moe=False)
|
||||
|
||||
torch_path = f"../out/reason_{lm_config.dim}{'_moe' if lm_config.use_moe else ''}.pth"
|
||||
|
||||
transformers_path = '../MiniMind2-Small-R1'
|
||||
|
||||
# convert torch to transformers model
|
||||
convert_torch2transformers(torch_path, transformers_path)
|
||||
|
||||
# # convert transformers to torch model
|
||||
# convert_transformers2torch(transformers_path, torch_path)
|
||||
@@ -0,0 +1,185 @@
|
||||
import csv
|
||||
import glob
|
||||
import os
|
||||
import re
|
||||
import json
|
||||
import jsonlines
|
||||
import pandas as pd
|
||||
from tqdm import tqdm
|
||||
|
||||
bos_token = "<s>"
|
||||
eos_token = "</s>"
|
||||
|
||||
|
||||
def pretrain_process():
|
||||
# 定义输入和输出路径
|
||||
input_dir = '../CCI3-HQ/data'
|
||||
output_file = '../dataset/pretrain_data_hq.csv'
|
||||
jsonl_files = glob.glob(os.path.join(input_dir, 'part_*.jsonl'))
|
||||
total_lines = 0
|
||||
print("正在计算总行数...")
|
||||
for file in jsonl_files:
|
||||
with open(file, 'r', encoding='utf-8') as f:
|
||||
for _ in f:
|
||||
total_lines += 1
|
||||
with open(output_file, 'w', newline='', encoding='utf-8') as csvfile:
|
||||
writer = csv.writer(csvfile)
|
||||
writer.writerow(['text', 'score']) # 写入表头
|
||||
for jsonl_file in jsonl_files:
|
||||
with open(jsonl_file, 'r', encoding='utf-8') as f:
|
||||
for line in tqdm(f, desc=f'处理 {os.path.basename(jsonl_file)}', total=total_lines, unit='行',
|
||||
leave=False):
|
||||
try:
|
||||
data = json.loads(line)
|
||||
text = data.get('text', '')
|
||||
score = data.get('score', 0)
|
||||
if len(text) <= 512 and score > 3.5:
|
||||
writer.writerow([text, score])
|
||||
except json.JSONDecodeError:
|
||||
continue
|
||||
print(f"筛选完成,结果已保存到 {output_file}")
|
||||
|
||||
|
||||
def sft_process():
|
||||
sft_file_name = 'sft_data.csv'
|
||||
|
||||
def process_and_write_data(data):
|
||||
q_lst, a_lst, history_lst = [], [], []
|
||||
for per in data:
|
||||
history, q, a = per['history'], per['q'], per['a']
|
||||
if not q or not a:
|
||||
continue
|
||||
history_len = sum(len(s) for s in history)
|
||||
message_len = history_len + len(q) + len(a)
|
||||
if message_len < 70 or message_len > 512:
|
||||
continue
|
||||
q_lst.append(q)
|
||||
a_lst.append(a)
|
||||
history_lst.append(history)
|
||||
|
||||
df = pd.DataFrame({'history': history_lst, 'q': q_lst, 'a': a_lst})
|
||||
df.to_csv(f'../dataset/{sft_file_name}',
|
||||
mode='a', header=False, index=False,
|
||||
lineterminator='\r\n', escapechar='\\', encoding='utf-8')
|
||||
|
||||
chunk_size = 1000
|
||||
data = []
|
||||
with open(f'../dataset/{sft_file_name}', 'w', encoding='utf-8') as f:
|
||||
f.write('history,q,a\n')
|
||||
|
||||
# sft_path = ['/root/shared-nvme/sft_data_zh.jsonl', '/root/shared-nvme/sft_data_en.jsonl']
|
||||
sft_path = ['/root/shared-nvme/sft_data_en.jsonl']
|
||||
chunk_num = 0
|
||||
for path in sft_path:
|
||||
with jsonlines.open(path) as reader:
|
||||
for idx, obj in enumerate(reader):
|
||||
try:
|
||||
data.append({
|
||||
'history': obj.get('history', ''),
|
||||
'q': obj.get('input', '') + obj.get('q', ''),
|
||||
'a': obj.get('output', '') + obj.get('a', '')
|
||||
})
|
||||
|
||||
if len(data) >= chunk_size:
|
||||
chunk_num += 1
|
||||
process_and_write_data(data)
|
||||
data = []
|
||||
if chunk_num % 100 == 0:
|
||||
print(f'chunk:{chunk_num} process end')
|
||||
except jsonlines.InvalidLineError as e:
|
||||
print(f"Skipping invalid JSON line {idx + 1}: {e}")
|
||||
continue
|
||||
|
||||
if data:
|
||||
process_and_write_data(data)
|
||||
data = []
|
||||
|
||||
|
||||
def rl_process():
|
||||
# 偏好数据默认只用中文(建议)
|
||||
input_paths = [
|
||||
# "../dataset/dpo_en.json",
|
||||
"../dataset/dpo_zh.json"
|
||||
]
|
||||
output_path = "../dataset/dpo_data.jsonl" # 修改输出文件扩展名为 .jsonl
|
||||
all_converted = []
|
||||
|
||||
for input_path in input_paths:
|
||||
with open(input_path, "r", encoding="utf-8") as f:
|
||||
data = json.load(f) # data is likely a list
|
||||
|
||||
for item in data:
|
||||
new_data = {
|
||||
"chosen": [],
|
||||
"rejected": []
|
||||
}
|
||||
for turn in item["conversations"]:
|
||||
role = "user" if turn["from"] == "human" else "assistant"
|
||||
message = {"role": role, "content": turn["value"]}
|
||||
new_data["chosen"].append(message)
|
||||
new_data["rejected"].append(message)
|
||||
new_data["chosen"].append({
|
||||
"role": "assistant",
|
||||
"content": item["chosen"]["value"]
|
||||
})
|
||||
new_data["rejected"].append({
|
||||
"role": "assistant",
|
||||
"content": item["rejected"]["value"]
|
||||
})
|
||||
all_converted.append(new_data)
|
||||
|
||||
with open(output_path, "w", encoding="utf-8") as f:
|
||||
for item in all_converted:
|
||||
f.write(json.dumps(item, ensure_ascii=False) + "\n")
|
||||
|
||||
|
||||
def lora_dataset():
|
||||
import json
|
||||
import csv
|
||||
|
||||
# 读取JSON文件
|
||||
with open('../dataset/Chinese-medical-dialogue.json', 'r', encoding='utf-8') as f:
|
||||
data = json.load(f)
|
||||
|
||||
# 准备CSV数据
|
||||
csv_data = []
|
||||
for item in data:
|
||||
# 提取input和output并去除首尾空白
|
||||
q = item['input'].strip()
|
||||
a = item['output'].strip()
|
||||
|
||||
# 检查长度是否符合要求
|
||||
if len(q) + len(a) < 160:
|
||||
csv_data.append({
|
||||
'history': '[]',
|
||||
'q': q,
|
||||
'a': a
|
||||
})
|
||||
|
||||
# 写入CSV文件
|
||||
with open('../dataset/medical_sft.csv', 'w', newline='', encoding='utf-8') as csvfile:
|
||||
fieldnames = ['history', 'q', 'a']
|
||||
writer = csv.DictWriter(csvfile, fieldnames=fieldnames)
|
||||
|
||||
writer.writeheader()
|
||||
writer.writerows(csv_data)
|
||||
|
||||
print(f'转换完成,共处理 {len(csv_data)} 条有效数据')
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
################
|
||||
# 1: pretrain
|
||||
# 2: sft
|
||||
# 3: RL
|
||||
################
|
||||
process_type = 4
|
||||
|
||||
if process_type == 1:
|
||||
pretrain_process()
|
||||
if process_type == 2:
|
||||
sft_process()
|
||||
if process_type == 3:
|
||||
rl_process()
|
||||
if process_type == 4:
|
||||
lora_dataset()
|
||||
@@ -0,0 +1,97 @@
|
||||
# from datasets import load_dataset
|
||||
#
|
||||
# dataset_paths = [
|
||||
# ['ceval/ceval-exam',
|
||||
# ['computer_network', 'operating_system', 'computer_architecture', 'college_programming', 'college_physics',
|
||||
# 'college_chemistry', 'advanced_mathematics', 'probability_and_statistics', 'discrete_mathematics',
|
||||
# 'electrical_engineer', 'metrology_engineer', 'high_school_mathematics', 'high_school_physics',
|
||||
# 'high_school_chemistry', 'high_school_biology', 'middle_school_mathematics', 'middle_school_biology',
|
||||
# 'middle_school_physics', 'middle_school_chemistry', 'veterinary_medicine', 'college_economics',
|
||||
# 'business_administration', 'marxism', 'mao_zedong_thought', 'education_science', 'teacher_qualification',
|
||||
# 'high_school_politics', 'high_school_geography', 'middle_school_politics', 'middle_school_geography',
|
||||
# 'modern_chinese_history', 'ideological_and_moral_cultivation', 'logic', 'law', 'chinese_language_and_literature',
|
||||
# 'art_studies', 'professional_tour_guide', 'legal_professional', 'high_school_chinese', 'high_school_history',
|
||||
# 'middle_school_history', 'civil_servant', 'sports_science', 'plant_protection', 'basic_medicine',
|
||||
# 'clinical_medicine', 'urban_and_rural_planner', 'accountant', 'fire_engineer',
|
||||
# 'environmental_impact_assessment_engineer', 'tax_accountant', 'physician']], # ceval*
|
||||
# ['haonan-li/cmmlu', [
|
||||
# 'agronomy', 'anatomy', 'ancient_chinese', 'arts', 'astronomy', 'business_ethics',
|
||||
# 'chinese_civil_service_exam', 'chinese_driving_rule', 'chinese_food_culture',
|
||||
# 'chinese_foreign_policy', 'chinese_history', 'chinese_literature',
|
||||
# 'chinese_teacher_qualification', 'clinical_knowledge', 'college_actuarial_science',
|
||||
# 'college_education', 'college_engineering_hydrology', 'college_law',
|
||||
# 'college_mathematics', 'college_medical_statistics', 'college_medicine',
|
||||
# 'computer_science', 'computer_security', 'conceptual_physics',
|
||||
# 'construction_project_management', 'economics', 'education', 'electrical_engineering',
|
||||
# 'elementary_chinese', 'elementary_commonsense', 'elementary_information_and_technology',
|
||||
# 'elementary_mathematics', 'ethnology', 'food_science', 'genetics', 'global_facts',
|
||||
# 'high_school_biology', 'high_school_chemistry', 'high_school_geography',
|
||||
# 'high_school_mathematics', 'high_school_physics', 'high_school_politics',
|
||||
# 'human_sexuality', 'international_law', 'journalism', 'jurisprudence',
|
||||
# 'legal_and_moral_basis', 'logical', 'machine_learning', 'management', 'marketing',
|
||||
# 'marxist_theory', 'modern_chinese', 'nutrition', 'philosophy', 'professional_accounting',
|
||||
# 'professional_law', 'professional_medicine', 'professional_psychology',
|
||||
# 'public_relations', 'security_study', 'sociology', 'sports_science',
|
||||
# 'traditional_chinese_medicine', 'virology', 'world_history', 'world_religions'
|
||||
# ]], # cmmlu*
|
||||
# ['tyouisen/aclue',
|
||||
# ['polysemy_resolution', 'poetry_sentiment_analysis', 'named_entity_recognition', 'basic_ancient_chinese',
|
||||
# 'poetry_context_prediction', 'sentence_segmentation', 'couplet_prediction', 'poetry_appreciate',
|
||||
# 'ancient_chinese_culture', 'ancient_phonetics', 'homographic_character_resolution', 'ancient_literature',
|
||||
# 'ancient_medical', 'poetry_quality_assessment', 'reading_comprehension']], # aclue
|
||||
# ['juletxara/mgsm', ['zh']], # mgsm_direct_zh
|
||||
# ['openbookqa', ['main']], # openbookqa
|
||||
# ['ZoneTwelve/tmmluplus',
|
||||
# ['dentistry', 'traditional_chinese_medicine_clinical_medicine', 'clinical_psychology', 'technical',
|
||||
# 'culinary_skills', 'mechanical', 'logic_reasoning', 'real_estate', 'general_principles_of_law', 'finance_banking',
|
||||
# 'anti_money_laundering', 'ttqav2', 'marketing_management', 'business_management', 'organic_chemistry',
|
||||
# 'advance_chemistry', 'physics', 'secondary_physics', 'human_behavior', 'national_protection', 'jce_humanities',
|
||||
# 'politic_science', 'agriculture', 'official_document_management', 'financial_analysis', 'pharmacy',
|
||||
# 'educational_psychology', 'statistics_and_machine_learning', 'management_accounting', 'introduction_to_law',
|
||||
# 'computer_science', 'veterinary_pathology', 'accounting', 'fire_science', 'optometry', 'insurance_studies',
|
||||
# 'pharmacology', 'taxation', 'education_(profession_level)', 'economics', 'veterinary_pharmacology',
|
||||
# 'nautical_science', 'occupational_therapy_for_psychological_disorders', 'trust_practice', 'geography_of_taiwan',
|
||||
# 'physical_education', 'auditing', 'administrative_law', 'basic_medical_science', 'macroeconomics', 'trade',
|
||||
# 'chinese_language_and_literature', 'tve_design', 'junior_science_exam', 'junior_math_exam', 'junior_chinese_exam',
|
||||
# 'junior_social_studies', 'tve_mathematics', 'tve_chinese_language', 'tve_natural_sciences', 'junior_chemistry',
|
||||
# 'music', 'education', 'three_principles_of_people', 'taiwanese_hokkien', 'engineering_math', 'linear_algebra']]
|
||||
# # tmmluplus
|
||||
#
|
||||
# ]
|
||||
#
|
||||
# for dataset_path in dataset_paths:
|
||||
# for dataset_name in dataset_path[1]:
|
||||
# datasets = load_dataset(dataset_path[0], dataset_name, cache_dir='./test_dataset_cache')
|
||||
#
|
||||
# """
|
||||
# export HF_HUB_OFFLINE=1 && lm_eval --model hf --model_args pretrained=/xxx/minimind/minimind-v2-small/,device=cuda,dtype=auto --tasks ceval* --batch_size 8 --trust_remote_code
|
||||
# """
|
||||
"""
|
||||
$env:HF_HUB_OFFLINE=1; lm_eval --model hf --model_args pretrained=../minimind-v2-small/,device=cuda,dtype=auto --tasks ceval* --batch_size 8 --trust_remote_code
|
||||
"""
|
||||
|
||||
import subprocess
|
||||
|
||||
# 定义要执行的命令
|
||||
command = (
|
||||
'set HF_HUB_OFFLINE=1 & '
|
||||
'lm_eval --model hf --model_args pretrained=../minimind-v2-small/,device=cuda,dtype=auto '
|
||||
'--tasks ceval* --batch_size 8 --trust_remote_code'
|
||||
)
|
||||
|
||||
# 使用 subprocess 执行命令
|
||||
try:
|
||||
process = subprocess.run(
|
||||
command,
|
||||
shell=True,
|
||||
check=True,
|
||||
text=True,
|
||||
stdout=subprocess.PIPE,
|
||||
stderr=subprocess.PIPE,
|
||||
)
|
||||
# 打印命令的输出
|
||||
print("STDOUT:", process.stdout)
|
||||
print("STDERR:", process.stderr)
|
||||
except subprocess.CalledProcessError as e:
|
||||
print(f"命令执行失败,返回码: {e.returncode}")
|
||||
print("STDERR:", e.stderr)
|
||||
@@ -0,0 +1,164 @@
|
||||
import argparse
|
||||
import json
|
||||
import os
|
||||
import sys
|
||||
|
||||
__package__ = "scripts"
|
||||
sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), '..')))
|
||||
import time
|
||||
import torch
|
||||
import warnings
|
||||
import uvicorn
|
||||
from fastapi import FastAPI, HTTPException
|
||||
from fastapi.responses import StreamingResponse
|
||||
from pydantic import BaseModel
|
||||
from transformers import AutoTokenizer, AutoModelForCausalLM
|
||||
from model.LMConfig import LMConfig
|
||||
from model.model import MiniMindLM
|
||||
from model.model_lora import apply_lora, load_lora
|
||||
|
||||
warnings.filterwarnings('ignore')
|
||||
|
||||
app = FastAPI()
|
||||
|
||||
|
||||
def init_model(args):
|
||||
tokenizer = AutoTokenizer.from_pretrained('../model/minimind_tokenizer')
|
||||
if args.load == 0:
|
||||
moe_path = '_moe' if args.use_moe else ''
|
||||
modes = {0: 'pretrain', 1: 'full_sft', 2: 'full_dist', 3: 'rlhf'}
|
||||
ckp = f'../{args.out_dir}/{modes[args.model_mode]}_{args.dim}{moe_path}.pth'
|
||||
|
||||
model = MiniMindLM(LMConfig(
|
||||
dim=args.dim,
|
||||
n_layers=args.n_layers,
|
||||
max_seq_len=args.max_seq_len,
|
||||
use_moe=args.use_moe
|
||||
))
|
||||
|
||||
state_dict = torch.load(ckp, map_location=device)
|
||||
model.load_state_dict({k: v for k, v in state_dict.items() if 'mask' not in k}, strict=True)
|
||||
|
||||
if args.lora_name != 'None':
|
||||
apply_lora(model)
|
||||
load_lora(model, f'../{args.out_dir}/{args.lora_name}_{args.dim}.pth')
|
||||
else:
|
||||
model = AutoModelForCausalLM.from_pretrained(
|
||||
'./MiniMind2',
|
||||
trust_remote_code=True
|
||||
)
|
||||
print(f'MiniMind模型参数量: {sum(p.numel() for p in model.parameters() if p.requires_grad) / 1e6:.2f}M(illion)')
|
||||
return model.eval().to(device), tokenizer
|
||||
|
||||
|
||||
class ChatRequest(BaseModel):
|
||||
model: str
|
||||
messages: list
|
||||
temperature: float = 0.7
|
||||
top_p: int = 0.92
|
||||
max_tokens: int = 8192
|
||||
stream: bool = False
|
||||
|
||||
|
||||
def generate_stream_response(messages, temperature, top_p, max_tokens):
|
||||
try:
|
||||
new_prompt = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)[-max_tokens:]
|
||||
x = tokenizer(new_prompt).data['input_ids']
|
||||
x = (torch.tensor(x, dtype=torch.long, device=device)[None, ...])
|
||||
with torch.no_grad():
|
||||
res_y = model.generate(
|
||||
x,
|
||||
eos_token_id=tokenizer.eos_token_id,
|
||||
max_new_tokens=max_tokens,
|
||||
temperature=temperature,
|
||||
top_p=top_p,
|
||||
stream=True,
|
||||
rp=1.,
|
||||
pad_token_id=tokenizer.pad_token_id
|
||||
)
|
||||
history_idx = 0
|
||||
for y in res_y:
|
||||
answer = tokenizer.decode(y[0].tolist(), skip_special_tokens=True)
|
||||
if (answer and answer[-1] == '�') or not answer:
|
||||
continue
|
||||
delta = answer[history_idx:]
|
||||
history_idx = len(answer)
|
||||
json_data = {
|
||||
'id': f'chatcmpl-{int(time.time())}',
|
||||
'object': 'chat.completion.chunk',
|
||||
'created': int(time.time()),
|
||||
'model': 'minimind',
|
||||
'choices': [{'index': 0, 'delta': {'content': delta}, 'finish_reason': None}]
|
||||
}
|
||||
yield f"data: {json.dumps(json_data)}\n\n"
|
||||
|
||||
except Exception as e:
|
||||
yield f"data: {json.dumps({'error': str(e)})}\n\n"
|
||||
|
||||
|
||||
@app.post("/v1/chat/completions")
|
||||
async def chat_completions(request: ChatRequest):
|
||||
try:
|
||||
if request.stream:
|
||||
return StreamingResponse(
|
||||
generate_stream_response(
|
||||
messages=request.messages,
|
||||
temperature=request.temperature,
|
||||
top_p=request.top_p,
|
||||
max_tokens=request.max_tokens
|
||||
),
|
||||
media_type="text/event-stream"
|
||||
)
|
||||
else:
|
||||
new_prompt = tokenizer.apply_chat_template(
|
||||
request.messages,
|
||||
tokenize=False,
|
||||
add_generation_prompt=True
|
||||
)[-request.max_tokens:]
|
||||
x = tokenizer(new_prompt).data['input_ids']
|
||||
x = (torch.tensor(x, dtype=torch.long, device=device)[None, ...])
|
||||
with torch.no_grad():
|
||||
res_y = model.generate(
|
||||
x,
|
||||
eos_token_id=tokenizer.eos_token_id,
|
||||
max_new_tokens=request.max_tokens,
|
||||
temperature=request.temperature,
|
||||
top_p=request.top_p,
|
||||
stream=False,
|
||||
rp=1.,
|
||||
pad_token_id=tokenizer.pad_token_id
|
||||
)
|
||||
answer = tokenizer.decode(res_y.squeeze()[x.shape[1]:].tolist(), skip_special_tokens=True)
|
||||
return {
|
||||
"id": f"chatcmpl-{int(time.time())}",
|
||||
"object": "chat.completion",
|
||||
"created": int(time.time()),
|
||||
"model": "minimind",
|
||||
"choices": [
|
||||
{
|
||||
"index": 0,
|
||||
"message": {"role": "assistant", "content": answer},
|
||||
"finish_reason": "stop"
|
||||
}
|
||||
]
|
||||
}
|
||||
|
||||
except Exception as e:
|
||||
raise HTTPException(status_code=500, detail=str(e))
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
parser = argparse.ArgumentParser(description="Server for MiniMind")
|
||||
parser.add_argument('--out_dir', default='out', type=str)
|
||||
parser.add_argument('--lora_name', default='None', type=str)
|
||||
parser.add_argument('--dim', default=512, type=int)
|
||||
parser.add_argument('--n_layers', default=8, type=int)
|
||||
parser.add_argument('--max_seq_len', default=8192, type=int)
|
||||
parser.add_argument('--use_moe', default=False, type=bool)
|
||||
parser.add_argument('--load', default=0, type=int, help="0: 从原生torch权重,1: 利用transformers加载")
|
||||
parser.add_argument('--model_mode', default=1, type=int, help="0: 预训练模型,1: SFT-Chat模型,2: RLHF-Chat模型")
|
||||
|
||||
device = 'cuda' if torch.cuda.is_available() else 'cpu'
|
||||
model, tokenizer = init_model(parser.parse_args())
|
||||
|
||||
uvicorn.run(app, host="0.0.0.0", port=8998)
|
||||
@@ -0,0 +1,152 @@
|
||||
import random
|
||||
from tqdm import tqdm
|
||||
from transformers import AutoTokenizer
|
||||
import json
|
||||
from datasets import load_dataset
|
||||
from tokenizers import (
|
||||
decoders,
|
||||
models,
|
||||
normalizers,
|
||||
pre_tokenizers,
|
||||
processors,
|
||||
trainers,
|
||||
Tokenizer,
|
||||
)
|
||||
import os
|
||||
|
||||
random.seed(42)
|
||||
|
||||
|
||||
def train_tokenizer():
|
||||
# 读取JSONL文件并提取文本数据
|
||||
def read_texts_from_jsonl(file_path):
|
||||
with open(file_path, 'r', encoding='utf-8') as f:
|
||||
for line in f:
|
||||
data = json.loads(line)
|
||||
yield data['text']
|
||||
|
||||
data_path = '../dataset/tokenizer_train.jsonl'
|
||||
|
||||
# 初始化tokenizer
|
||||
tokenizer = Tokenizer(models.BPE())
|
||||
tokenizer.pre_tokenizer = pre_tokenizers.ByteLevel(add_prefix_space=False)
|
||||
|
||||
# 定义特殊token
|
||||
special_tokens = ["<unk>", "<s>", "</s>"]
|
||||
|
||||
# 设置训练器并添加特殊token
|
||||
trainer = trainers.BpeTrainer(
|
||||
vocab_size=6400,
|
||||
special_tokens=special_tokens, # 确保这三个token被包含
|
||||
show_progress=True,
|
||||
initial_alphabet=pre_tokenizers.ByteLevel.alphabet()
|
||||
)
|
||||
|
||||
# 读取文本数据
|
||||
texts = read_texts_from_jsonl(data_path)
|
||||
|
||||
# 训练tokenizer
|
||||
tokenizer.train_from_iterator(texts, trainer=trainer)
|
||||
|
||||
# 设置解码器
|
||||
tokenizer.decoder = decoders.ByteLevel()
|
||||
|
||||
# 检查特殊token的索引
|
||||
assert tokenizer.token_to_id("<unk>") == 0
|
||||
assert tokenizer.token_to_id("<s>") == 1
|
||||
assert tokenizer.token_to_id("</s>") == 2
|
||||
|
||||
# 保存tokenizer
|
||||
tokenizer_dir = "../model/minimind_tokenizer"
|
||||
os.makedirs(tokenizer_dir, exist_ok=True)
|
||||
tokenizer.save(os.path.join(tokenizer_dir, "tokenizer.json"))
|
||||
tokenizer.model.save("../model/minimind_tokenizer")
|
||||
|
||||
# 手动创建配置文件
|
||||
config = {
|
||||
"add_bos_token": False,
|
||||
"add_eos_token": False,
|
||||
"add_prefix_space": False,
|
||||
"added_tokens_decoder": {
|
||||
"0": {
|
||||
"content": "<unk>",
|
||||
"lstrip": False,
|
||||
"normalized": False,
|
||||
"rstrip": False,
|
||||
"single_word": False,
|
||||
"special": True
|
||||
},
|
||||
"1": {
|
||||
"content": "<s>",
|
||||
"lstrip": False,
|
||||
"normalized": False,
|
||||
"rstrip": False,
|
||||
"single_word": False,
|
||||
"special": True
|
||||
},
|
||||
"2": {
|
||||
"content": "</s>",
|
||||
"lstrip": False,
|
||||
"normalized": False,
|
||||
"rstrip": False,
|
||||
"single_word": False,
|
||||
"special": True
|
||||
}
|
||||
},
|
||||
"additional_special_tokens": [],
|
||||
"bos_token": "<s>",
|
||||
"clean_up_tokenization_spaces": False,
|
||||
"eos_token": "</s>",
|
||||
"legacy": True,
|
||||
"model_max_length": 32768,
|
||||
"pad_token": "<unk>",
|
||||
"sp_model_kwargs": {},
|
||||
"spaces_between_special_tokens": False,
|
||||
"tokenizer_class": "PreTrainedTokenizerFast",
|
||||
"unk_token": "<unk>",
|
||||
"chat_template": "{% if messages[0]['role'] == 'system' %}{% set system_message = messages[0]['content'] %}{{ '<s>system\\n' + system_message + '</s>\\n' }}{% else %}{{ '<s>system\\n你是 MiniMind,是一个有用的人工智能助手。</s>\\n' }}{% endif %}{% for message in messages %}{% set content = message['content'] %}{% if message['role'] == 'user' %}{{ '<s>user\\n' + content + '</s>\\n<s>assistant\\n' }}{% elif message['role'] == 'assistant' %}{{ content + '</s>' + '\\n' }}{% endif %}{% endfor %}"
|
||||
}
|
||||
|
||||
# 保存配置文件
|
||||
with open(os.path.join(tokenizer_dir, "tokenizer_config.json"), "w", encoding="utf-8") as config_file:
|
||||
json.dump(config, config_file, ensure_ascii=False, indent=4)
|
||||
|
||||
print("Tokenizer training completed and saved.")
|
||||
|
||||
|
||||
def eval_tokenizer():
|
||||
from transformers import AutoTokenizer
|
||||
|
||||
# 加载预训练的tokenizer
|
||||
tokenizer = AutoTokenizer.from_pretrained("../model/minimind_tokenizer")
|
||||
|
||||
messages = [
|
||||
{"role": "system", "content": "你是一个优秀的聊天机器人,总是给我正确的回应!"},
|
||||
{"role": "user", "content": '你来自哪里?'},
|
||||
{"role": "assistant", "content": '我来自地球'}
|
||||
]
|
||||
new_prompt = tokenizer.apply_chat_template(
|
||||
messages,
|
||||
tokenize=False
|
||||
)
|
||||
print(new_prompt)
|
||||
|
||||
# 获取实际词汇表长度(包括特殊符号)
|
||||
actual_vocab_size = len(tokenizer)
|
||||
print('tokenizer实际词表长度:', actual_vocab_size)
|
||||
|
||||
model_inputs = tokenizer(new_prompt)
|
||||
print('encoder长度:', len(model_inputs['input_ids']))
|
||||
|
||||
input_ids = model_inputs['input_ids']
|
||||
response = tokenizer.decode(input_ids, skip_special_tokens=True)
|
||||
print('decoder和原始文本是否一致:', response == new_prompt)
|
||||
|
||||
|
||||
def main():
|
||||
# train_tokenizer()
|
||||
eval_tokenizer()
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
main()
|
||||
@@ -0,0 +1,293 @@
|
||||
import random
|
||||
import re
|
||||
import time
|
||||
|
||||
import numpy as np
|
||||
import streamlit as st
|
||||
import torch
|
||||
|
||||
st.set_page_config(page_title="MiniMind", initial_sidebar_state="collapsed")
|
||||
|
||||
# 在文件开头的 CSS 样式中修改按钮样式
|
||||
st.markdown("""
|
||||
<style>
|
||||
/* 添加操作按钮样式 */
|
||||
.stButton button {
|
||||
border-radius: 50% !important; /* 改为圆形 */
|
||||
width: 32px !important; /* 固定宽度 */
|
||||
height: 32px !important; /* 固定高度 */
|
||||
padding: 0 !important; /* 移除内边距 */
|
||||
background-color: transparent !important;
|
||||
border: 1px solid #ddd !important;
|
||||
display: flex !important;
|
||||
align-items: center !important;
|
||||
justify-content: center !important;
|
||||
font-size: 14px !important;
|
||||
color: #666 !important; /* 更柔和的颜色 */
|
||||
margin: 5px 10px 5px 0 !important; /* 调整按钮间距 */
|
||||
}
|
||||
.stButton button:hover {
|
||||
border-color: #999 !important;
|
||||
color: #333 !important;
|
||||
background-color: #f5f5f5 !important;
|
||||
}
|
||||
.stMainBlockContainer > div:first-child {
|
||||
margin-top: -50px !important;
|
||||
}
|
||||
.stApp > div:last-child {
|
||||
margin-bottom: -35px !important;
|
||||
}
|
||||
|
||||
/* 重置按钮基础样式 */
|
||||
.stButton > button {
|
||||
all: unset !important; /* 重置所有默认样式 */
|
||||
box-sizing: border-box !important;
|
||||
border-radius: 50% !important;
|
||||
width: 18px !important;
|
||||
height: 18px !important;
|
||||
min-width: 18px !important;
|
||||
min-height: 18px !important;
|
||||
max-width: 18px !important;
|
||||
max-height: 18px !important;
|
||||
padding: 0 !important;
|
||||
background-color: transparent !important;
|
||||
border: 1px solid #ddd !important;
|
||||
display: flex !important;
|
||||
align-items: center !important;
|
||||
justify-content: center !important;
|
||||
font-size: 14px !important;
|
||||
color: #888 !important;
|
||||
cursor: pointer !important;
|
||||
transition: all 0.2s ease !important;
|
||||
margin: 0 2px !important; /* 调整这里的 margin 值 */
|
||||
}
|
||||
|
||||
</style>
|
||||
""", unsafe_allow_html=True)
|
||||
|
||||
system_prompt = []
|
||||
device = "cuda" if torch.cuda.is_available() else "cpu"
|
||||
|
||||
|
||||
def process_assistant_content(content):
|
||||
if 'R1' not in MODEL_PATHS[selected_model][1]:
|
||||
return content
|
||||
|
||||
if '<think>' in content and '</think>' in content:
|
||||
content = re.sub(r'(<think>)(.*?)(</think>)',
|
||||
r'<details style="font-style: italic; background: rgba(222, 222, 222, 0.5); padding: 10px; border-radius: 10px;"><summary style="font-weight:bold;">推理内容(展开)</summary>\2</details>',
|
||||
content,
|
||||
flags=re.DOTALL)
|
||||
|
||||
if '<think>' in content and '</think>' not in content:
|
||||
content = re.sub(r'<think>(.*?)$',
|
||||
r'<details open style="font-style: italic; background: rgba(222, 222, 222, 0.5); padding: 10px; border-radius: 10px;"><summary style="font-weight:bold;">推理中...</summary>\1</details>',
|
||||
content,
|
||||
flags=re.DOTALL)
|
||||
|
||||
if '<think>' not in content and '</think>' in content:
|
||||
content = re.sub(r'(.*?)</think>',
|
||||
r'<details style="font-style: italic; background: rgba(222, 222, 222, 0.5); padding: 10px; border-radius: 10px;"><summary style="font-weight:bold;">推理内容(展开)</summary>\1</details>',
|
||||
content,
|
||||
flags=re.DOTALL)
|
||||
|
||||
return content
|
||||
|
||||
|
||||
@st.cache_resource
|
||||
def load_model_tokenizer(model_path):
|
||||
model = AutoModelForCausalLM.from_pretrained(
|
||||
model_path,
|
||||
trust_remote_code=True
|
||||
)
|
||||
tokenizer = AutoTokenizer.from_pretrained(
|
||||
model_path,
|
||||
use_fast=False,
|
||||
trust_remote_code=True
|
||||
)
|
||||
model = model.eval().to(device)
|
||||
return model, tokenizer
|
||||
|
||||
|
||||
def clear_chat_messages():
|
||||
del st.session_state.messages
|
||||
del st.session_state.chat_messages
|
||||
|
||||
|
||||
def init_chat_messages():
|
||||
if "messages" in st.session_state:
|
||||
for i, message in enumerate(st.session_state.messages):
|
||||
if message["role"] == "assistant":
|
||||
with st.chat_message("assistant", avatar=image_url):
|
||||
st.markdown(process_assistant_content(message["content"]), unsafe_allow_html=True)
|
||||
# 在消息内容下方添加按钮
|
||||
if st.button("🗑", key=f"delete_{i}"):
|
||||
st.session_state.messages.pop(i)
|
||||
st.session_state.messages.pop(i - 1)
|
||||
st.session_state.chat_messages.pop(i)
|
||||
st.session_state.chat_messages.pop(i - 1)
|
||||
st.rerun()
|
||||
else:
|
||||
st.markdown(
|
||||
f'<div style="display: flex; justify-content: flex-end;"><div style="display: inline-block; margin: 10px 0; padding: 8px 12px 8px 12px; background-color: #ddd; border-radius: 10px; color: black;">{message["content"]}</div></div>',
|
||||
unsafe_allow_html=True)
|
||||
|
||||
else:
|
||||
st.session_state.messages = []
|
||||
st.session_state.chat_messages = []
|
||||
|
||||
return st.session_state.messages
|
||||
|
||||
|
||||
# 添加这两个辅助函数
|
||||
def regenerate_answer(index):
|
||||
st.session_state.messages.pop()
|
||||
st.session_state.chat_messages.pop()
|
||||
st.rerun()
|
||||
|
||||
|
||||
def delete_conversation(index):
|
||||
st.session_state.messages.pop(index)
|
||||
st.session_state.messages.pop(index - 1)
|
||||
st.session_state.chat_messages.pop(index)
|
||||
st.session_state.chat_messages.pop(index - 1)
|
||||
st.rerun()
|
||||
|
||||
|
||||
# 侧边栏模型选择
|
||||
st.sidebar.title("模型设定调整")
|
||||
|
||||
st.sidebar.text("【注】训练数据偏差,增加上下文记忆时\n多轮对话(较单轮)容易出现能力衰减")
|
||||
st.session_state.history_chat_num = st.sidebar.slider("Number of Historical Dialogues", 0, 6, 0, step=2)
|
||||
# st.session_state.history_chat_num = 0
|
||||
st.session_state.max_new_tokens = st.sidebar.slider("Max Sequence Length", 256, 8192, 8192, step=1)
|
||||
st.session_state.top_p = st.sidebar.slider("Top-P", 0.8, 0.99, 0.85, step=0.01)
|
||||
st.session_state.temperature = st.sidebar.slider("Temperature", 0.6, 1.2, 0.85, step=0.01)
|
||||
|
||||
# 模型路径映射
|
||||
MODEL_PATHS = {
|
||||
"MiniMind2-Pro-R1 (0.1B)": ["../MiniMind2-Pro-R1", "MiniMind2-Pro-R1"],
|
||||
"MiniMind2-R1 (0.05B)": ["../MiniMind2-R1", "MiniMind2-R1"],
|
||||
"MiniMind2-Pro (0.1B)": ["../MiniMind2-Pro", "MiniMind2-Pro"],
|
||||
"MiniMind2 (0.05B)": ["../MiniMind2", "MiniMind2"],
|
||||
"MiniMind2-Small (0.02B)": ["../MiniMind2-Small", "MiniMind2-Small"],
|
||||
"MiniMind-V1 (0.1B)": ["../minimind-v1", "MiniMind-V1"],
|
||||
"MiniMind-V1-Small (0.02B)": ["../minimind-v1-small", "MiniMind-V1 Small"],
|
||||
}
|
||||
|
||||
selected_model = st.sidebar.selectbox('Models', list(MODEL_PATHS.keys()), index=0) # 默认选择 MiniMind2
|
||||
model_path = MODEL_PATHS[selected_model][0]
|
||||
|
||||
slogan = f"Hi, I'm {MODEL_PATHS[selected_model][1]}"
|
||||
|
||||
image_url = "https://www.modelscope.cn/api/v1/studio/gongjy/MiniMind/repo?Revision=master&FilePath=images%2Flogo2.png&View=true"
|
||||
|
||||
st.markdown(
|
||||
f'<div style="display: flex; flex-direction: column; align-items: center; text-align: center; margin: 0; padding: 0;">'
|
||||
'<div style="font-style: italic; font-weight: 900; margin: 0; padding-top: 4px; display: flex; align-items: center; justify-content: center; flex-wrap: wrap; width: 100%;">'
|
||||
f'<img src="{image_url}" style="width: 45px; height: 45px; "> '
|
||||
f'<span style="font-size: 26px; margin-left: 10px;">{slogan}</span>'
|
||||
'</div>'
|
||||
'<span style="color: #bbb; font-style: italic; margin-top: 6px; margin-bottom: 10px;">内容完全由AI生成,请务必仔细甄别<br>Content AI-generated, please discern with care</span>'
|
||||
'</div>',
|
||||
unsafe_allow_html=True
|
||||
)
|
||||
|
||||
|
||||
def setup_seed(seed):
|
||||
random.seed(seed)
|
||||
np.random.seed(seed)
|
||||
torch.manual_seed(seed)
|
||||
torch.cuda.manual_seed(seed)
|
||||
torch.cuda.manual_seed_all(seed)
|
||||
torch.backends.cudnn.deterministic = True
|
||||
torch.backends.cudnn.benchmark = False
|
||||
|
||||
|
||||
def main():
|
||||
model, tokenizer = load_model_tokenizer(model_path)
|
||||
|
||||
# 初始化消息列表
|
||||
if "messages" not in st.session_state:
|
||||
st.session_state.messages = []
|
||||
st.session_state.chat_messages = []
|
||||
|
||||
# Use session state messages
|
||||
messages = st.session_state.messages
|
||||
|
||||
# 在显示历史消息的循环中
|
||||
for i, message in enumerate(messages):
|
||||
if message["role"] == "assistant":
|
||||
with st.chat_message("assistant", avatar=image_url):
|
||||
st.markdown(process_assistant_content(message["content"]), unsafe_allow_html=True)
|
||||
if st.button("×", key=f"delete_{i}"):
|
||||
# 删除当前消息及其之后的所有消息
|
||||
st.session_state.messages = st.session_state.messages[:i - 1]
|
||||
st.session_state.chat_messages = st.session_state.chat_messages[:i - 1]
|
||||
st.rerun()
|
||||
else:
|
||||
st.markdown(
|
||||
f'<div style="display: flex; justify-content: flex-end;"><div style="display: inline-block; margin: 10px 0; padding: 8px 12px 8px 12px; background-color: gray; border-radius: 10px; color:white; ">{message["content"]}</div></div>',
|
||||
unsafe_allow_html=True)
|
||||
|
||||
# 处理新的输入或重新生成
|
||||
prompt = st.chat_input(key="input", placeholder="给 MiniMind 发送消息")
|
||||
|
||||
# 检查是否需要重新生成
|
||||
if hasattr(st.session_state, 'regenerate') and st.session_state.regenerate:
|
||||
prompt = st.session_state.last_user_message
|
||||
regenerate_index = st.session_state.regenerate_index # 获取重新生成的位置
|
||||
# 清除所有重新生成相关的状态
|
||||
delattr(st.session_state, 'regenerate')
|
||||
delattr(st.session_state, 'last_user_message')
|
||||
delattr(st.session_state, 'regenerate_index')
|
||||
|
||||
if prompt:
|
||||
st.markdown(
|
||||
f'<div style="display: flex; justify-content: flex-end;"><div style="display: inline-block; margin: 10px 0; padding: 8px 12px 8px 12px; background-color: gray; border-radius: 10px; color:white; ">{prompt}</div></div>',
|
||||
unsafe_allow_html=True)
|
||||
messages.append({"role": "user", "content": prompt})
|
||||
st.session_state.chat_messages.append({"role": "user", "content": prompt})
|
||||
|
||||
with st.chat_message("assistant", avatar=image_url):
|
||||
placeholder = st.empty()
|
||||
random_seed = random.randint(0, 2 ** 32 - 1)
|
||||
setup_seed(random_seed)
|
||||
|
||||
st.session_state.chat_messages = system_prompt + st.session_state.chat_messages[
|
||||
-(st.session_state.history_chat_num + 1):]
|
||||
new_prompt = tokenizer.apply_chat_template(
|
||||
st.session_state.chat_messages,
|
||||
tokenize=False,
|
||||
add_generation_prompt=True
|
||||
)[-(st.session_state.max_new_tokens - 1):]
|
||||
|
||||
x = torch.tensor(tokenizer(new_prompt)['input_ids'], device=device).unsqueeze(0)
|
||||
with torch.no_grad():
|
||||
res_y = model.generate(x, tokenizer.eos_token_id, max_new_tokens=st.session_state.max_new_tokens,
|
||||
temperature=st.session_state.temperature,
|
||||
top_p=st.session_state.top_p, stream=True)
|
||||
try:
|
||||
for y in res_y:
|
||||
answer = tokenizer.decode(y[0].tolist(), skip_special_tokens=True)
|
||||
if (answer and answer[-1] == '�') or not answer:
|
||||
continue
|
||||
placeholder.markdown(process_assistant_content(answer), unsafe_allow_html=True)
|
||||
except StopIteration:
|
||||
print("No answer")
|
||||
|
||||
assistant_answer = answer.replace(new_prompt, "")
|
||||
messages.append({"role": "assistant", "content": assistant_answer})
|
||||
st.session_state.chat_messages.append({"role": "assistant", "content": assistant_answer})
|
||||
|
||||
with st.empty():
|
||||
if st.button("×", key=f"delete_{len(messages) - 1}"):
|
||||
st.session_state.messages = st.session_state.messages[:-2]
|
||||
st.session_state.chat_messages = st.session_state.chat_messages[:-2]
|
||||
st.rerun()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
from transformers import AutoModelForCausalLM, AutoTokenizer
|
||||
|
||||
main()
|
||||
Reference in New Issue
Block a user