mirror of
https://github.com/jingyaogong/minimind.git
synced 2026-06-06 00:04:50 +00:00
update scripts file
This commit is contained in:
@@ -1,185 +0,0 @@
|
||||
import csv
|
||||
import glob
|
||||
import os
|
||||
import re
|
||||
import json
|
||||
import jsonlines
|
||||
import pandas as pd
|
||||
from tqdm import tqdm
|
||||
|
||||
bos_token = "<s>"
|
||||
eos_token = "</s>"
|
||||
|
||||
|
||||
def pretrain_process():
|
||||
# 定义输入和输出路径
|
||||
input_dir = '../CCI3-HQ/data'
|
||||
output_file = '../dataset/pretrain_data_hq.csv'
|
||||
jsonl_files = glob.glob(os.path.join(input_dir, 'part_*.jsonl'))
|
||||
total_lines = 0
|
||||
print("正在计算总行数...")
|
||||
for file in jsonl_files:
|
||||
with open(file, 'r', encoding='utf-8') as f:
|
||||
for _ in f:
|
||||
total_lines += 1
|
||||
with open(output_file, 'w', newline='', encoding='utf-8') as csvfile:
|
||||
writer = csv.writer(csvfile)
|
||||
writer.writerow(['text', 'score']) # 写入表头
|
||||
for jsonl_file in jsonl_files:
|
||||
with open(jsonl_file, 'r', encoding='utf-8') as f:
|
||||
for line in tqdm(f, desc=f'处理 {os.path.basename(jsonl_file)}', total=total_lines, unit='行',
|
||||
leave=False):
|
||||
try:
|
||||
data = json.loads(line)
|
||||
text = data.get('text', '')
|
||||
score = data.get('score', 0)
|
||||
if len(text) <= 512 and score > 3.5:
|
||||
writer.writerow([text, score])
|
||||
except json.JSONDecodeError:
|
||||
continue
|
||||
print(f"筛选完成,结果已保存到 {output_file}")
|
||||
|
||||
|
||||
def sft_process():
|
||||
sft_file_name = 'sft_data.csv'
|
||||
|
||||
def process_and_write_data(data):
|
||||
q_lst, a_lst, history_lst = [], [], []
|
||||
for per in data:
|
||||
history, q, a = per['history'], per['q'], per['a']
|
||||
if not q or not a:
|
||||
continue
|
||||
history_len = sum(len(s) for s in history)
|
||||
message_len = history_len + len(q) + len(a)
|
||||
if message_len < 70 or message_len > 512:
|
||||
continue
|
||||
q_lst.append(q)
|
||||
a_lst.append(a)
|
||||
history_lst.append(history)
|
||||
|
||||
df = pd.DataFrame({'history': history_lst, 'q': q_lst, 'a': a_lst})
|
||||
df.to_csv(f'../dataset/{sft_file_name}',
|
||||
mode='a', header=False, index=False,
|
||||
lineterminator='\r\n', escapechar='\\', encoding='utf-8')
|
||||
|
||||
chunk_size = 1000
|
||||
data = []
|
||||
with open(f'../dataset/{sft_file_name}', 'w', encoding='utf-8') as f:
|
||||
f.write('history,q,a\n')
|
||||
|
||||
# sft_path = ['/root/shared-nvme/sft_data_zh.jsonl', '/root/shared-nvme/sft_data_en.jsonl']
|
||||
sft_path = ['/root/shared-nvme/sft_data_en.jsonl']
|
||||
chunk_num = 0
|
||||
for path in sft_path:
|
||||
with jsonlines.open(path) as reader:
|
||||
for idx, obj in enumerate(reader):
|
||||
try:
|
||||
data.append({
|
||||
'history': obj.get('history', ''),
|
||||
'q': obj.get('input', '') + obj.get('q', ''),
|
||||
'a': obj.get('output', '') + obj.get('a', '')
|
||||
})
|
||||
|
||||
if len(data) >= chunk_size:
|
||||
chunk_num += 1
|
||||
process_and_write_data(data)
|
||||
data = []
|
||||
if chunk_num % 100 == 0:
|
||||
print(f'chunk:{chunk_num} process end')
|
||||
except jsonlines.InvalidLineError as e:
|
||||
print(f"Skipping invalid JSON line {idx + 1}: {e}")
|
||||
continue
|
||||
|
||||
if data:
|
||||
process_and_write_data(data)
|
||||
data = []
|
||||
|
||||
|
||||
def rl_process():
|
||||
# 偏好数据默认只用中文(建议)
|
||||
input_paths = [
|
||||
# "../dataset/dpo_en.json",
|
||||
"../dataset/dpo_zh.json"
|
||||
]
|
||||
output_path = "../dataset/dpo_data.jsonl" # 修改输出文件扩展名为 .jsonl
|
||||
all_converted = []
|
||||
|
||||
for input_path in input_paths:
|
||||
with open(input_path, "r", encoding="utf-8") as f:
|
||||
data = json.load(f) # data is likely a list
|
||||
|
||||
for item in data:
|
||||
new_data = {
|
||||
"chosen": [],
|
||||
"rejected": []
|
||||
}
|
||||
for turn in item["conversations"]:
|
||||
role = "user" if turn["from"] == "human" else "assistant"
|
||||
message = {"role": role, "content": turn["value"]}
|
||||
new_data["chosen"].append(message)
|
||||
new_data["rejected"].append(message)
|
||||
new_data["chosen"].append({
|
||||
"role": "assistant",
|
||||
"content": item["chosen"]["value"]
|
||||
})
|
||||
new_data["rejected"].append({
|
||||
"role": "assistant",
|
||||
"content": item["rejected"]["value"]
|
||||
})
|
||||
all_converted.append(new_data)
|
||||
|
||||
with open(output_path, "w", encoding="utf-8") as f:
|
||||
for item in all_converted:
|
||||
f.write(json.dumps(item, ensure_ascii=False) + "\n")
|
||||
|
||||
|
||||
def lora_dataset():
|
||||
import json
|
||||
import csv
|
||||
|
||||
# 读取JSON文件
|
||||
with open('../dataset/Chinese-medical-dialogue.json', 'r', encoding='utf-8') as f:
|
||||
data = json.load(f)
|
||||
|
||||
# 准备CSV数据
|
||||
csv_data = []
|
||||
for item in data:
|
||||
# 提取input和output并去除首尾空白
|
||||
q = item['input'].strip()
|
||||
a = item['output'].strip()
|
||||
|
||||
# 检查长度是否符合要求
|
||||
if len(q) + len(a) < 160:
|
||||
csv_data.append({
|
||||
'history': '[]',
|
||||
'q': q,
|
||||
'a': a
|
||||
})
|
||||
|
||||
# 写入CSV文件
|
||||
with open('../dataset/medical_sft.csv', 'w', newline='', encoding='utf-8') as csvfile:
|
||||
fieldnames = ['history', 'q', 'a']
|
||||
writer = csv.DictWriter(csvfile, fieldnames=fieldnames)
|
||||
|
||||
writer.writeheader()
|
||||
writer.writerows(csv_data)
|
||||
|
||||
print(f'转换完成,共处理 {len(csv_data)} 条有效数据')
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
################
|
||||
# 1: pretrain
|
||||
# 2: sft
|
||||
# 3: RL
|
||||
################
|
||||
process_type = 4
|
||||
|
||||
if process_type == 1:
|
||||
pretrain_process()
|
||||
if process_type == 2:
|
||||
sft_process()
|
||||
if process_type == 3:
|
||||
rl_process()
|
||||
if process_type == 4:
|
||||
lora_dataset()
|
||||
@@ -1,97 +0,0 @@
|
||||
# from datasets import load_dataset
|
||||
#
|
||||
# dataset_paths = [
|
||||
# ['ceval/ceval-exam',
|
||||
# ['computer_network', 'operating_system', 'computer_architecture', 'college_programming', 'college_physics',
|
||||
# 'college_chemistry', 'advanced_mathematics', 'probability_and_statistics', 'discrete_mathematics',
|
||||
# 'electrical_engineer', 'metrology_engineer', 'high_school_mathematics', 'high_school_physics',
|
||||
# 'high_school_chemistry', 'high_school_biology', 'middle_school_mathematics', 'middle_school_biology',
|
||||
# 'middle_school_physics', 'middle_school_chemistry', 'veterinary_medicine', 'college_economics',
|
||||
# 'business_administration', 'marxism', 'mao_zedong_thought', 'education_science', 'teacher_qualification',
|
||||
# 'high_school_politics', 'high_school_geography', 'middle_school_politics', 'middle_school_geography',
|
||||
# 'modern_chinese_history', 'ideological_and_moral_cultivation', 'logic', 'law', 'chinese_language_and_literature',
|
||||
# 'art_studies', 'professional_tour_guide', 'legal_professional', 'high_school_chinese', 'high_school_history',
|
||||
# 'middle_school_history', 'civil_servant', 'sports_science', 'plant_protection', 'basic_medicine',
|
||||
# 'clinical_medicine', 'urban_and_rural_planner', 'accountant', 'fire_engineer',
|
||||
# 'environmental_impact_assessment_engineer', 'tax_accountant', 'physician']], # ceval*
|
||||
# ['haonan-li/cmmlu', [
|
||||
# 'agronomy', 'anatomy', 'ancient_chinese', 'arts', 'astronomy', 'business_ethics',
|
||||
# 'chinese_civil_service_exam', 'chinese_driving_rule', 'chinese_food_culture',
|
||||
# 'chinese_foreign_policy', 'chinese_history', 'chinese_literature',
|
||||
# 'chinese_teacher_qualification', 'clinical_knowledge', 'college_actuarial_science',
|
||||
# 'college_education', 'college_engineering_hydrology', 'college_law',
|
||||
# 'college_mathematics', 'college_medical_statistics', 'college_medicine',
|
||||
# 'computer_science', 'computer_security', 'conceptual_physics',
|
||||
# 'construction_project_management', 'economics', 'education', 'electrical_engineering',
|
||||
# 'elementary_chinese', 'elementary_commonsense', 'elementary_information_and_technology',
|
||||
# 'elementary_mathematics', 'ethnology', 'food_science', 'genetics', 'global_facts',
|
||||
# 'high_school_biology', 'high_school_chemistry', 'high_school_geography',
|
||||
# 'high_school_mathematics', 'high_school_physics', 'high_school_politics',
|
||||
# 'human_sexuality', 'international_law', 'journalism', 'jurisprudence',
|
||||
# 'legal_and_moral_basis', 'logical', 'machine_learning', 'management', 'marketing',
|
||||
# 'marxist_theory', 'modern_chinese', 'nutrition', 'philosophy', 'professional_accounting',
|
||||
# 'professional_law', 'professional_medicine', 'professional_psychology',
|
||||
# 'public_relations', 'security_study', 'sociology', 'sports_science',
|
||||
# 'traditional_chinese_medicine', 'virology', 'world_history', 'world_religions'
|
||||
# ]], # cmmlu*
|
||||
# ['tyouisen/aclue',
|
||||
# ['polysemy_resolution', 'poetry_sentiment_analysis', 'named_entity_recognition', 'basic_ancient_chinese',
|
||||
# 'poetry_context_prediction', 'sentence_segmentation', 'couplet_prediction', 'poetry_appreciate',
|
||||
# 'ancient_chinese_culture', 'ancient_phonetics', 'homographic_character_resolution', 'ancient_literature',
|
||||
# 'ancient_medical', 'poetry_quality_assessment', 'reading_comprehension']], # aclue
|
||||
# ['juletxara/mgsm', ['zh']], # mgsm_direct_zh
|
||||
# ['openbookqa', ['main']], # openbookqa
|
||||
# ['ZoneTwelve/tmmluplus',
|
||||
# ['dentistry', 'traditional_chinese_medicine_clinical_medicine', 'clinical_psychology', 'technical',
|
||||
# 'culinary_skills', 'mechanical', 'logic_reasoning', 'real_estate', 'general_principles_of_law', 'finance_banking',
|
||||
# 'anti_money_laundering', 'ttqav2', 'marketing_management', 'business_management', 'organic_chemistry',
|
||||
# 'advance_chemistry', 'physics', 'secondary_physics', 'human_behavior', 'national_protection', 'jce_humanities',
|
||||
# 'politic_science', 'agriculture', 'official_document_management', 'financial_analysis', 'pharmacy',
|
||||
# 'educational_psychology', 'statistics_and_machine_learning', 'management_accounting', 'introduction_to_law',
|
||||
# 'computer_science', 'veterinary_pathology', 'accounting', 'fire_science', 'optometry', 'insurance_studies',
|
||||
# 'pharmacology', 'taxation', 'education_(profession_level)', 'economics', 'veterinary_pharmacology',
|
||||
# 'nautical_science', 'occupational_therapy_for_psychological_disorders', 'trust_practice', 'geography_of_taiwan',
|
||||
# 'physical_education', 'auditing', 'administrative_law', 'basic_medical_science', 'macroeconomics', 'trade',
|
||||
# 'chinese_language_and_literature', 'tve_design', 'junior_science_exam', 'junior_math_exam', 'junior_chinese_exam',
|
||||
# 'junior_social_studies', 'tve_mathematics', 'tve_chinese_language', 'tve_natural_sciences', 'junior_chemistry',
|
||||
# 'music', 'education', 'three_principles_of_people', 'taiwanese_hokkien', 'engineering_math', 'linear_algebra']]
|
||||
# # tmmluplus
|
||||
#
|
||||
# ]
|
||||
#
|
||||
# for dataset_path in dataset_paths:
|
||||
# for dataset_name in dataset_path[1]:
|
||||
# datasets = load_dataset(dataset_path[0], dataset_name, cache_dir='./test_dataset_cache')
|
||||
#
|
||||
# """
|
||||
# export HF_HUB_OFFLINE=1 && lm_eval --model hf --model_args pretrained=/xxx/minimind/minimind-v2-small/,device=cuda,dtype=auto --tasks ceval* --batch_size 8 --trust_remote_code
|
||||
# """
|
||||
"""
|
||||
$env:HF_HUB_OFFLINE=1; lm_eval --model hf --model_args pretrained=../minimind-v2-small/,device=cuda,dtype=auto --tasks ceval* --batch_size 8 --trust_remote_code
|
||||
"""
|
||||
|
||||
import subprocess
|
||||
|
||||
# 定义要执行的命令
|
||||
command = (
|
||||
'set HF_HUB_OFFLINE=1 & '
|
||||
'lm_eval --model hf --model_args pretrained=../minimind-v2-small/,device=cuda,dtype=auto '
|
||||
'--tasks ceval* --batch_size 8 --trust_remote_code'
|
||||
)
|
||||
|
||||
# 使用 subprocess 执行命令
|
||||
try:
|
||||
process = subprocess.run(
|
||||
command,
|
||||
shell=True,
|
||||
check=True,
|
||||
text=True,
|
||||
stdout=subprocess.PIPE,
|
||||
stderr=subprocess.PIPE,
|
||||
)
|
||||
# 打印命令的输出
|
||||
print("STDOUT:", process.stdout)
|
||||
print("STDERR:", process.stderr)
|
||||
except subprocess.CalledProcessError as e:
|
||||
print(f"命令执行失败,返回码: {e.returncode}")
|
||||
print("STDERR:", e.stderr)
|
||||
Reference in New Issue
Block a user