[feat] stream load data

This commit is contained in:
jingyaogong 2025-12-28 16:58:52 +08:00
parent 7eae14f3ce
commit 4a5c9f5ece

View File

@ -4,6 +4,7 @@ import numpy as np
from torch.utils.data import Dataset, DataLoader from torch.utils.data import Dataset, DataLoader
import torch import torch
import os import os
from datasets import load_dataset
os.environ["TOKENIZERS_PARALLELISM"] = "false" os.environ["TOKENIZERS_PARALLELISM"] = "false"
@ -13,15 +14,7 @@ class PretrainDataset(Dataset):
super().__init__() super().__init__()
self.tokenizer = tokenizer self.tokenizer = tokenizer
self.max_length = max_length self.max_length = max_length
self.samples = self.load_data(data_path) self.samples = load_dataset('json', data_files=data_path, split='train')
def load_data(self, path):
samples = []
with open(path, 'r', encoding='utf-8') as f:
for line_num, line in enumerate(f, 1):
data = json.loads(line.strip())
samples.append(data)
return samples
def __len__(self): def __len__(self):
return len(self.samples) return len(self.samples)
@ -51,21 +44,13 @@ class SFTDataset(Dataset):
super().__init__() super().__init__()
self.tokenizer = tokenizer self.tokenizer = tokenizer
self.max_length = max_length self.max_length = max_length
self.samples = self.load_data(jsonl_path) self.samples = load_dataset('json', data_files=jsonl_path, split='train')
self.bos_id = tokenizer(f'{tokenizer.bos_token}assistant', add_special_tokens=False).input_ids self.bos_id = tokenizer(f'{tokenizer.bos_token}assistant', add_special_tokens=False).input_ids
self.eos_id = tokenizer(f'{tokenizer.eos_token}', add_special_tokens=False).input_ids self.eos_id = tokenizer(f'{tokenizer.eos_token}', add_special_tokens=False).input_ids
def __len__(self): def __len__(self):
return len(self.samples) return len(self.samples)
def load_data(self, path):
samples = []
with open(path, 'r', encoding='utf-8') as f:
for line_num, line in enumerate(f, 1):
data = json.loads(line.strip())
samples.append(data)
return samples
def _create_chat_prompt(self, cs): def _create_chat_prompt(self, cs):
messages = cs.copy() messages = cs.copy()
tools = cs[0]["functions"] if (cs and cs[0]["role"] == "system" and cs[0].get("functions")) else None tools = cs[0]["functions"] if (cs and cs[0]["role"] == "system" and cs[0].get("functions")) else None
@ -127,12 +112,7 @@ class DPODataset(Dataset):
self.padding = tokenizer.pad_token_id if tokenizer.pad_token_id is not None else 0 self.padding = tokenizer.pad_token_id if tokenizer.pad_token_id is not None else 0
self.bos_id = tokenizer(f'{tokenizer.bos_token}assistant', add_special_tokens=False).input_ids self.bos_id = tokenizer(f'{tokenizer.bos_token}assistant', add_special_tokens=False).input_ids
self.eos_id = tokenizer(f'{tokenizer.eos_token}', add_special_tokens=False).input_ids self.eos_id = tokenizer(f'{tokenizer.eos_token}', add_special_tokens=False).input_ids
with open(file_path, 'r', encoding='utf-8') as f: self.data = load_dataset('json', data_files=file_path, split='train')
self.data = []
for line in f:
line = line.strip()
obj = json.loads(line)
self.data.append(obj)
def __len__(self): def __len__(self):
return len(self.data) return len(self.data)
@ -200,21 +180,13 @@ class RLAIFDataset(Dataset):
super().__init__() super().__init__()
self.tokenizer = tokenizer self.tokenizer = tokenizer
self.max_length = max_length self.max_length = max_length
self.samples = self.load_data(jsonl_path) self.samples = load_dataset('json', data_files=jsonl_path, split='train')
self.bos_id = tokenizer(f'{tokenizer.bos_token}assistant', add_special_tokens=False).input_ids self.bos_id = tokenizer(f'{tokenizer.bos_token}assistant', add_special_tokens=False).input_ids
self.eos_id = tokenizer(f'{tokenizer.eos_token}', add_special_tokens=False).input_ids self.eos_id = tokenizer(f'{tokenizer.eos_token}', add_special_tokens=False).input_ids
def __len__(self): def __len__(self):
return len(self.samples) return len(self.samples)
def load_data(self, path):
samples = []
with open(path, 'r', encoding='utf-8') as f:
for line_num, line in enumerate(f, 1):
data = json.loads(line.strip())
samples.append(data)
return samples
def _create_chat_prompt(self, conversations): def _create_chat_prompt(self, conversations):
"""构建符合ChatML格式的对话""" """构建符合ChatML格式的对话"""
messages = [] messages = []