diff --git a/dataset/lm_dataset.py b/dataset/lm_dataset.py
index 1dd1558..cf663d4 100644
--- a/dataset/lm_dataset.py
+++ b/dataset/lm_dataset.py
@@ -1,9 +1,33 @@
from torch.utils.data import Dataset
import torch
import os
+import random
from datasets import load_dataset
os.environ["TOKENIZERS_PARALLELISM"] = "false"
+def pre_processing_chat(conversations, add_system_ratio=0.2):
+ SYSTEM_PROMPTS = [
+ "你是一个知识丰富的AI,尽力为用户提供准确的信息。",
+ "你是minimind,一个小巧但有用的语言模型。",
+ "你是一个专业的AI助手,请提供有价值的回答。",
+ "你是minimind,请尽力帮助用户解决问题。",
+ "你是一个可靠的AI,请给出准确的回答。",
+ "You are a helpful AI assistant.",
+ "You are minimind, a lightweight intelligent assistant.",
+ "You are a friendly chatbot. Please answer the user's questions carefully.",
+ "You are a knowledgeable AI. Try your best to provide accurate information.",
+ "You are minimind, a small but useful language model."
+ ]
+ if conversations and conversations[0].get('role') != 'system':
+ if random.random() < add_system_ratio:
+ return [{'role': 'system', 'content': random.choice(SYSTEM_PROMPTS)}] + conversations
+ return conversations
+
+def post_processing_chat(prompt_content, empty_think_ratio=0.1):
+ if '\n\n\n\n' in prompt_content and random.random() > empty_think_ratio:
+ prompt_content = prompt_content.replace('\n\n\n\n', '')
+ return prompt_content
+
class PretrainDataset(Dataset):
def __init__(self, data_path, tokenizer, max_length=512):
super().__init__()
@@ -37,9 +61,9 @@ class SFTDataset(Dataset):
def __len__(self):
return len(self.samples)
- def create_chat_prompt(self, cs):
- messages = cs.copy()
- tools = cs[0]["functions"] if (cs and cs[0]["role"] == "system" and cs[0].get("functions")) else None
+ def create_chat_prompt(self, conversations):
+ messages = conversations.copy()
+ tools = conversations[0]["functions"] if (conversations and conversations[0]["role"] == "system" and conversations[0].get("functions")) else None
return self.tokenizer.apply_chat_template(
messages,
tokenize=False,
@@ -67,7 +91,9 @@ class SFTDataset(Dataset):
def __getitem__(self, index):
sample = self.samples[index]
- prompt = self.create_chat_prompt(sample['conversations'])
+ conversations = pre_processing_chat(sample['conversations'])
+ prompt = self.create_chat_prompt(conversations)
+ prompt = post_processing_chat(prompt)
input_ids = self.tokenizer(prompt).input_ids[:self.max_length]
input_ids += [self.tokenizer.pad_token_id] * (self.max_length - len(input_ids))
labels = self.generate_labels(input_ids)
@@ -87,22 +113,24 @@ class DPODataset(Dataset):
self.padding = tokenizer.pad_token_id if tokenizer.pad_token_id is not None else 0
self.bos_id = tokenizer(f'{tokenizer.bos_token}assistant\n', add_special_tokens=False).input_ids
self.eos_id = tokenizer(f'{tokenizer.eos_token}\n', add_special_tokens=False).input_ids
- self.data = load_dataset('json', data_files=file_path, split='train')
+ self.samples = load_dataset('json', data_files=file_path, split='train')
def __len__(self):
- return len(self.data)
+ return len(self.samples)
def __getitem__(self, index):
- item = self.data[index]
- chosen = item['chosen'] # 是一个 list,里面包含若干 {role, content}
- rejected = item['rejected'] # 同上
+ sample = self.samples[index]
+ chosen = sample['chosen'] # 是一个 list,里面包含若干 {role, content}
+ rejected = sample['rejected'] # 同上
chosen_prompt = self.tokenizer.apply_chat_template(
chosen, tokenize=False, add_generation_prompt=False
)
+ chosen_prompt = post_processing_chat(chosen_prompt)
rejected_prompt = self.tokenizer.apply_chat_template(
rejected, tokenize=False, add_generation_prompt=False
)
+ rejected_prompt = post_processing_chat(rejected_prompt)
chosen_encoding = self.tokenizer(
chosen_prompt, truncation=True, max_length=self.max_length, padding='max_length'
)
@@ -169,11 +197,13 @@ class RLAIFDataset(Dataset):
role = 'user' if i % 2 == 0 else 'assistant'
messages.append({"role": role, "content": turn['content']})
answer = turn['content']
- return self.tokenizer.apply_chat_template(
+ prompt = self.tokenizer.apply_chat_template(
messages[:-1],
tokenize=False,
add_generation_prompt=True # 这里需要True
- ), answer
+ )
+ prompt = post_processing_chat(prompt)
+ return prompt, answer
def __getitem__(self, index):
sample = self.samples[index]