mirror of
https://github.com/jingyaogong/minimind.git
synced 2026-04-23 15:58:15 +08:00
[update] simplify loader
This commit is contained in:
parent
0b4a8ad4aa
commit
7641985d14
@ -1,14 +1,9 @@
|
||||
import json
|
||||
import pandas as pd
|
||||
import numpy as np
|
||||
from torch.utils.data import Dataset, DataLoader
|
||||
from torch.utils.data import Dataset
|
||||
import torch
|
||||
import os
|
||||
from datasets import load_dataset
|
||||
|
||||
os.environ["TOKENIZERS_PARALLELISM"] = "false"
|
||||
|
||||
|
||||
class PretrainDataset(Dataset):
|
||||
def __init__(self, data_path, tokenizer, max_length=512):
|
||||
super().__init__()
|
||||
@ -51,7 +46,7 @@ class SFTDataset(Dataset):
|
||||
def __len__(self):
|
||||
return len(self.samples)
|
||||
|
||||
def _create_chat_prompt(self, cs):
|
||||
def create_chat_prompt(self, cs):
|
||||
messages = cs.copy()
|
||||
tools = cs[0]["functions"] if (cs and cs[0]["role"] == "system" and cs[0].get("functions")) else None
|
||||
return self.tokenizer.apply_chat_template(
|
||||
@ -61,7 +56,7 @@ class SFTDataset(Dataset):
|
||||
tools=tools
|
||||
)
|
||||
|
||||
def _generate_loss_mask(self, input_ids):
|
||||
def generate_loss_mask(self, input_ids):
|
||||
loss_mask = [0] * len(input_ids)
|
||||
i = 0
|
||||
while i < len(input_ids):
|
||||
@ -81,13 +76,10 @@ class SFTDataset(Dataset):
|
||||
|
||||
def __getitem__(self, index):
|
||||
sample = self.samples[index]
|
||||
# 构建对话提示
|
||||
prompt = self._create_chat_prompt(sample['conversations'])
|
||||
prompt = self.create_chat_prompt(sample['conversations'])
|
||||
input_ids = self.tokenizer(prompt).input_ids[:self.max_length]
|
||||
input_ids += [self.tokenizer.pad_token_id] * (self.max_length - len(input_ids))
|
||||
|
||||
# 生成动态损失掩码
|
||||
loss_mask = self._generate_loss_mask(input_ids)
|
||||
loss_mask = self.generate_loss_mask(input_ids)
|
||||
|
||||
# 构建训练数据
|
||||
X = torch.tensor(input_ids[:-1], dtype=torch.long)
|
||||
@ -136,10 +128,10 @@ class DPODataset(Dataset):
|
||||
)
|
||||
|
||||
chosen_input_ids = chosen_encoding['input_ids']
|
||||
chosen_loss_mask = self._generate_loss_mask(chosen_input_ids)
|
||||
chosen_loss_mask = self.generate_loss_mask(chosen_input_ids)
|
||||
|
||||
rejected_input_ids = rejected_encoding['input_ids']
|
||||
rejected_loss_mask = self._generate_loss_mask(rejected_input_ids)
|
||||
rejected_loss_mask = self.generate_loss_mask(rejected_input_ids)
|
||||
x_chosen = torch.tensor(chosen_input_ids[:-1], dtype=torch.long)
|
||||
y_chosen = torch.tensor(chosen_input_ids[1:], dtype=torch.long)
|
||||
mask_chosen = torch.tensor(chosen_loss_mask[1:], dtype=torch.long)
|
||||
@ -156,7 +148,7 @@ class DPODataset(Dataset):
|
||||
'mask_rejected': mask_rejected
|
||||
}
|
||||
|
||||
def _generate_loss_mask(self, input_ids):
|
||||
def generate_loss_mask(self, input_ids):
|
||||
loss_mask = [0] * len(input_ids)
|
||||
i = 0
|
||||
while i < len(input_ids):
|
||||
@ -187,8 +179,7 @@ class RLAIFDataset(Dataset):
|
||||
def __len__(self):
|
||||
return len(self.samples)
|
||||
|
||||
def _create_chat_prompt(self, conversations):
|
||||
"""构建符合ChatML格式的对话"""
|
||||
def create_chat_prompt(self, conversations):
|
||||
messages = []
|
||||
answer = ''
|
||||
for i, turn in enumerate(conversations):
|
||||
@ -203,14 +194,12 @@ class RLAIFDataset(Dataset):
|
||||
|
||||
def __getitem__(self, index):
|
||||
sample = self.samples[index]
|
||||
# 构建对话提示
|
||||
prompt, answer = self._create_chat_prompt(sample['conversations'])
|
||||
prompt, answer = self.create_chat_prompt(sample['conversations'])
|
||||
|
||||
return {
|
||||
'prompt': prompt,
|
||||
'answer': answer
|
||||
}
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
pass
|
||||
pass
|
||||
Loading…
Reference in New Issue
Block a user