[update] simplify loader

This commit is contained in:
jingyaogong 2026-01-06 01:20:52 +08:00
parent 0b4a8ad4aa
commit 7641985d14

View File

@ -1,14 +1,9 @@
import json
import pandas as pd
import numpy as np
from torch.utils.data import Dataset, DataLoader
from torch.utils.data import Dataset
import torch
import os
from datasets import load_dataset
os.environ["TOKENIZERS_PARALLELISM"] = "false"
class PretrainDataset(Dataset):
def __init__(self, data_path, tokenizer, max_length=512):
super().__init__()
@ -51,7 +46,7 @@ class SFTDataset(Dataset):
def __len__(self):
return len(self.samples)
def _create_chat_prompt(self, cs):
def create_chat_prompt(self, cs):
messages = cs.copy()
tools = cs[0]["functions"] if (cs and cs[0]["role"] == "system" and cs[0].get("functions")) else None
return self.tokenizer.apply_chat_template(
@ -61,7 +56,7 @@ class SFTDataset(Dataset):
tools=tools
)
def _generate_loss_mask(self, input_ids):
def generate_loss_mask(self, input_ids):
loss_mask = [0] * len(input_ids)
i = 0
while i < len(input_ids):
@ -81,13 +76,10 @@ class SFTDataset(Dataset):
def __getitem__(self, index):
sample = self.samples[index]
# 构建对话提示
prompt = self._create_chat_prompt(sample['conversations'])
prompt = self.create_chat_prompt(sample['conversations'])
input_ids = self.tokenizer(prompt).input_ids[:self.max_length]
input_ids += [self.tokenizer.pad_token_id] * (self.max_length - len(input_ids))
# 生成动态损失掩码
loss_mask = self._generate_loss_mask(input_ids)
loss_mask = self.generate_loss_mask(input_ids)
# 构建训练数据
X = torch.tensor(input_ids[:-1], dtype=torch.long)
@ -136,10 +128,10 @@ class DPODataset(Dataset):
)
chosen_input_ids = chosen_encoding['input_ids']
chosen_loss_mask = self._generate_loss_mask(chosen_input_ids)
chosen_loss_mask = self.generate_loss_mask(chosen_input_ids)
rejected_input_ids = rejected_encoding['input_ids']
rejected_loss_mask = self._generate_loss_mask(rejected_input_ids)
rejected_loss_mask = self.generate_loss_mask(rejected_input_ids)
x_chosen = torch.tensor(chosen_input_ids[:-1], dtype=torch.long)
y_chosen = torch.tensor(chosen_input_ids[1:], dtype=torch.long)
mask_chosen = torch.tensor(chosen_loss_mask[1:], dtype=torch.long)
@ -156,7 +148,7 @@ class DPODataset(Dataset):
'mask_rejected': mask_rejected
}
def _generate_loss_mask(self, input_ids):
def generate_loss_mask(self, input_ids):
loss_mask = [0] * len(input_ids)
i = 0
while i < len(input_ids):
@ -187,8 +179,7 @@ class RLAIFDataset(Dataset):
def __len__(self):
return len(self.samples)
def _create_chat_prompt(self, conversations):
"""构建符合ChatML格式的对话"""
def create_chat_prompt(self, conversations):
messages = []
answer = ''
for i, turn in enumerate(conversations):
@ -203,14 +194,12 @@ class RLAIFDataset(Dataset):
def __getitem__(self, index):
sample = self.samples[index]
# 构建对话提示
prompt, answer = self._create_chat_prompt(sample['conversations'])
prompt, answer = self.create_chat_prompt(sample['conversations'])
return {
'prompt': prompt,
'answer': answer
}
if __name__ == "__main__":
pass
pass