[feat] data process

This commit is contained in:
jingyaogong 2026-02-06 01:17:57 +08:00
parent 11a44340ba
commit ccc190da05

View File

@ -1,9 +1,33 @@
from torch.utils.data import Dataset
import torch
import os
import random
from datasets import load_dataset
os.environ["TOKENIZERS_PARALLELISM"] = "false"
def pre_processing_chat(conversations, add_system_ratio=0.2):
SYSTEM_PROMPTS = [
"你是一个知识丰富的AI尽力为用户提供准确的信息。",
"你是minimind一个小巧但有用的语言模型。",
"你是一个专业的AI助手请提供有价值的回答。",
"你是minimind请尽力帮助用户解决问题。",
"你是一个可靠的AI请给出准确的回答。",
"You are a helpful AI assistant.",
"You are minimind, a lightweight intelligent assistant.",
"You are a friendly chatbot. Please answer the user's questions carefully.",
"You are a knowledgeable AI. Try your best to provide accurate information.",
"You are minimind, a small but useful language model."
]
if conversations and conversations[0].get('role') != 'system':
if random.random() < add_system_ratio:
return [{'role': 'system', 'content': random.choice(SYSTEM_PROMPTS)}] + conversations
return conversations
def post_processing_chat(prompt_content, empty_think_ratio=0.1):
if '<think>\n\n</think>\n\n' in prompt_content and random.random() > empty_think_ratio:
prompt_content = prompt_content.replace('<think>\n\n</think>\n\n', '')
return prompt_content
class PretrainDataset(Dataset):
def __init__(self, data_path, tokenizer, max_length=512):
super().__init__()
@ -37,9 +61,9 @@ class SFTDataset(Dataset):
def __len__(self):
return len(self.samples)
def create_chat_prompt(self, cs):
messages = cs.copy()
tools = cs[0]["functions"] if (cs and cs[0]["role"] == "system" and cs[0].get("functions")) else None
def create_chat_prompt(self, conversations):
messages = conversations.copy()
tools = conversations[0]["functions"] if (conversations and conversations[0]["role"] == "system" and conversations[0].get("functions")) else None
return self.tokenizer.apply_chat_template(
messages,
tokenize=False,
@ -67,7 +91,9 @@ class SFTDataset(Dataset):
def __getitem__(self, index):
sample = self.samples[index]
prompt = self.create_chat_prompt(sample['conversations'])
conversations = pre_processing_chat(sample['conversations'])
prompt = self.create_chat_prompt(conversations)
prompt = post_processing_chat(prompt)
input_ids = self.tokenizer(prompt).input_ids[:self.max_length]
input_ids += [self.tokenizer.pad_token_id] * (self.max_length - len(input_ids))
labels = self.generate_labels(input_ids)
@ -87,22 +113,24 @@ class DPODataset(Dataset):
self.padding = tokenizer.pad_token_id if tokenizer.pad_token_id is not None else 0
self.bos_id = tokenizer(f'{tokenizer.bos_token}assistant\n', add_special_tokens=False).input_ids
self.eos_id = tokenizer(f'{tokenizer.eos_token}\n', add_special_tokens=False).input_ids
self.data = load_dataset('json', data_files=file_path, split='train')
self.samples = load_dataset('json', data_files=file_path, split='train')
def __len__(self):
return len(self.data)
return len(self.samples)
def __getitem__(self, index):
item = self.data[index]
chosen = item['chosen'] # 是一个 list里面包含若干 {role, content}
rejected = item['rejected'] # 同上
sample = self.samples[index]
chosen = sample['chosen'] # 是一个 list里面包含若干 {role, content}
rejected = sample['rejected'] # 同上
chosen_prompt = self.tokenizer.apply_chat_template(
chosen, tokenize=False, add_generation_prompt=False
)
chosen_prompt = post_processing_chat(chosen_prompt)
rejected_prompt = self.tokenizer.apply_chat_template(
rejected, tokenize=False, add_generation_prompt=False
)
rejected_prompt = post_processing_chat(rejected_prompt)
chosen_encoding = self.tokenizer(
chosen_prompt, truncation=True, max_length=self.max_length, padding='max_length'
)
@ -169,11 +197,13 @@ class RLAIFDataset(Dataset):
role = 'user' if i % 2 == 0 else 'assistant'
messages.append({"role": role, "content": turn['content']})
answer = turn['content']
return self.tokenizer.apply_chat_template(
prompt = self.tokenizer.apply_chat_template(
messages[:-1],
tokenize=False,
add_generation_prompt=True # 这里需要True
), answer
)
prompt = post_processing_chat(prompt)
return prompt, answer
def __getitem__(self, index):
sample = self.samples[index]