mirror of https://github.com/jingyaogong/minimind.git
synced 2026-04-25 08:48:16 +08:00
[feat] data process
This commit is contained in:
parent 11a44340ba
commit ccc190da05
@@ -1,9 +1,33 @@
 from torch.utils.data import Dataset
 import torch
 import os
+import random
 from datasets import load_dataset
 os.environ["TOKENIZERS_PARALLELISM"] = "false"
 
+def pre_processing_chat(conversations, add_system_ratio=0.2):
+    SYSTEM_PROMPTS = [
+        "你是一个知识丰富的AI,尽力为用户提供准确的信息。",
+        "你是minimind,一个小巧但有用的语言模型。",
+        "你是一个专业的AI助手,请提供有价值的回答。",
+        "你是minimind,请尽力帮助用户解决问题。",
+        "你是一个可靠的AI,请给出准确的回答。",
+        "You are a helpful AI assistant.",
+        "You are minimind, a lightweight intelligent assistant.",
+        "You are a friendly chatbot. Please answer the user's questions carefully.",
+        "You are a knowledgeable AI. Try your best to provide accurate information.",
+        "You are minimind, a small but useful language model."
+    ]
+    if conversations and conversations[0].get('role') != 'system':
+        if random.random() < add_system_ratio:
+            return [{'role': 'system', 'content': random.choice(SYSTEM_PROMPTS)}] + conversations
+    return conversations
+
+def post_processing_chat(prompt_content, empty_think_ratio=0.1):
+    if '<think>\n\n</think>\n\n' in prompt_content and random.random() > empty_think_ratio:
+        prompt_content = prompt_content.replace('<think>\n\n</think>\n\n', '')
+    return prompt_content
+
 class PretrainDataset(Dataset):
     def __init__(self, data_path, tokenizer, max_length=512):
         super().__init__()
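A quick sanity-check sketch of the two new helpers (standalone and illustrative: the seed and sample strings are not from the repo, and the two functions from the hunk above are assumed to be in scope):

    import random

    # Assumes pre_processing_chat / post_processing_chat from the patched module.
    random.seed(0)  # illustrative seed so the branches taken are reproducible

    conversation = [{'role': 'user', 'content': 'hi'}]

    # With add_system_ratio=0.2, roughly 1 in 5 conversations that lack a
    # system turn gets a random SYSTEM_PROMPTS entry prepended.
    print(pre_processing_chat(conversation, add_system_ratio=0.2))

    # post_processing_chat operates on the rendered prompt string: with
    # empty_think_ratio=0.1, about 90% of empty <think> blocks are stripped.
    rendered = 'assistant\n<think>\n\n</think>\n\nhello'
    print(post_processing_chat(rendered, empty_think_ratio=0.1))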
@@ -37,9 +61,9 @@ class SFTDataset(Dataset):
     def __len__(self):
         return len(self.samples)
 
-    def create_chat_prompt(self, cs):
-        messages = cs.copy()
-        tools = cs[0]["functions"] if (cs and cs[0]["role"] == "system" and cs[0].get("functions")) else None
+    def create_chat_prompt(self, conversations):
+        messages = conversations.copy()
+        tools = conversations[0]["functions"] if (conversations and conversations[0]["role"] == "system" and conversations[0].get("functions")) else None
         return self.tokenizer.apply_chat_template(
             messages,
             tokenize=False,
@@ -67,7 +91,9 @@ class SFTDataset(Dataset):
 
     def __getitem__(self, index):
         sample = self.samples[index]
-        prompt = self.create_chat_prompt(sample['conversations'])
+        conversations = pre_processing_chat(sample['conversations'])
+        prompt = self.create_chat_prompt(conversations)
+        prompt = post_processing_chat(prompt)
         input_ids = self.tokenizer(prompt).input_ids[:self.max_length]
         input_ids += [self.tokenizer.pad_token_id] * (self.max_length - len(input_ids))
         labels = self.generate_labels(input_ids)
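The ordering in the new __getitem__ matters: pre_processing_chat edits the message dicts before templating, while post_processing_chat edits the rendered string after. A minimal sketch with a toy renderer (DummyTokenizer below is a stand-in for illustration, not minimind's tokenizer; the helper functions from the first hunk are assumed in scope):

    class DummyTokenizer:
        # Toy ChatML-ish rendering; just enough to show the ordering.
        def apply_chat_template(self, messages, tokenize=False, add_generation_prompt=False):
            return ''.join(f"<|{m['role']}|>{m['content']}\n" for m in messages)

    sample = {'conversations': [
        {'role': 'user', 'content': 'hi'},
        {'role': 'assistant', 'content': '<think>\n\n</think>\n\nhello'},
    ]}

    conversations = pre_processing_chat(sample['conversations'])  # may prepend a system turn
    prompt = DummyTokenizer().apply_chat_template(conversations)  # dicts -> string
    prompt = post_processing_chat(prompt)                         # usually strips empty <think>
    print(prompt)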
@@ -87,22 +113,24 @@ class DPODataset(Dataset):
         self.padding = tokenizer.pad_token_id if tokenizer.pad_token_id is not None else 0
         self.bos_id = tokenizer(f'{tokenizer.bos_token}assistant\n', add_special_tokens=False).input_ids
         self.eos_id = tokenizer(f'{tokenizer.eos_token}\n', add_special_tokens=False).input_ids
-        self.data = load_dataset('json', data_files=file_path, split='train')
+        self.samples = load_dataset('json', data_files=file_path, split='train')
 
     def __len__(self):
-        return len(self.data)
+        return len(self.samples)
 
     def __getitem__(self, index):
-        item = self.data[index]
-        chosen = item['chosen']  # a list of {role, content} messages
-        rejected = item['rejected']  # same as above
+        sample = self.samples[index]
+        chosen = sample['chosen']  # a list of {role, content} messages
+        rejected = sample['rejected']  # same as above
         chosen_prompt = self.tokenizer.apply_chat_template(
             chosen, tokenize=False, add_generation_prompt=False
         )
+        chosen_prompt = post_processing_chat(chosen_prompt)
 
         rejected_prompt = self.tokenizer.apply_chat_template(
             rejected, tokenize=False, add_generation_prompt=False
         )
+        rejected_prompt = post_processing_chat(rejected_prompt)
         chosen_encoding = self.tokenizer(
             chosen_prompt, truncation=True, max_length=self.max_length, padding='max_length'
         )
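One subtlety worth noting: post_processing_chat draws a fresh random number per call, so the chosen and rejected prompts of the same pair are processed independently, and occasionally one side keeps its empty <think> block while the other has it stripped. A small sketch of that behavior (seed and strings are illustrative; post_processing_chat from the first hunk is assumed in scope):

    import random

    marker = '<think>\n\n</think>\n\n'
    random.seed(7)  # illustrative; different seeds give different splits
    diverged = 0
    for _ in range(1000):
        a = post_processing_chat(marker + 'good answer')
        b = post_processing_chat(marker + 'bad answer')
        diverged += a.startswith(marker) != b.startswith(marker)
    # With empty_think_ratio=0.1, exactly one side keeps the block with
    # probability 2 * 0.1 * 0.9 = 0.18, so expect roughly 180 here.
    print(diverged)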
@@ -169,11 +197,13 @@ class RLAIFDataset(Dataset):
             role = 'user' if i % 2 == 0 else 'assistant'
             messages.append({"role": role, "content": turn['content']})
             answer = turn['content']
-        return self.tokenizer.apply_chat_template(
+        prompt = self.tokenizer.apply_chat_template(
             messages[:-1],
             tokenize=False,
             add_generation_prompt=True  # True is required here
-        ), answer
+        )
+        prompt = post_processing_chat(prompt)
+        return prompt, answer
 
     def __getitem__(self, index):
         sample = self.samples[index]
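Why add_generation_prompt=True matters in RLAIF: the final ground-truth turn is held out as answer and only the history is rendered, so the prompt must end with an open assistant header for the policy model to complete. The refactor from returning the template call inline to binding it as prompt is what makes room for the post_processing_chat hook. A toy sketch (the render helper is illustrative, not the real chat template; post_processing_chat is assumed in scope):

    def render(messages, add_generation_prompt=False):
        out = ''.join(f"<|{m['role']}|>{m['content']}\n" for m in messages)
        return out + ('<|assistant|>' if add_generation_prompt else '')

    history = [{'role': 'user', 'content': 'What is 2+2?'},
               {'role': 'assistant', 'content': '4'}]
    answer = history[-1]['content']                            # held-out reference
    prompt = render(history[:-1], add_generation_prompt=True)  # ends with open header
    prompt = post_processing_chat(prompt)                      # same hook as SFT/DPO
    print(prompt)  # "<|user|>What is 2+2?\n<|assistant|>"
    print(answer)  # "4"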