mirror of
https://github.com/jingyaogong/minimind.git
synced 2026-01-13 19:57:20 +08:00
[feat] stream load data
This commit is contained in:
parent
7eae14f3ce
commit
4a5c9f5ece
@ -4,6 +4,7 @@ import numpy as np
|
||||
from torch.utils.data import Dataset, DataLoader
|
||||
import torch
|
||||
import os
|
||||
from datasets import load_dataset
|
||||
|
||||
os.environ["TOKENIZERS_PARALLELISM"] = "false"
|
||||
|
||||
@ -13,15 +14,7 @@ class PretrainDataset(Dataset):
|
||||
super().__init__()
|
||||
self.tokenizer = tokenizer
|
||||
self.max_length = max_length
|
||||
self.samples = self.load_data(data_path)
|
||||
|
||||
def load_data(self, path):
|
||||
samples = []
|
||||
with open(path, 'r', encoding='utf-8') as f:
|
||||
for line_num, line in enumerate(f, 1):
|
||||
data = json.loads(line.strip())
|
||||
samples.append(data)
|
||||
return samples
|
||||
self.samples = load_dataset('json', data_files=data_path, split='train')
|
||||
|
||||
def __len__(self):
|
||||
return len(self.samples)
|
||||
@ -51,21 +44,13 @@ class SFTDataset(Dataset):
|
||||
super().__init__()
|
||||
self.tokenizer = tokenizer
|
||||
self.max_length = max_length
|
||||
self.samples = self.load_data(jsonl_path)
|
||||
self.samples = load_dataset('json', data_files=jsonl_path, split='train')
|
||||
self.bos_id = tokenizer(f'{tokenizer.bos_token}assistant', add_special_tokens=False).input_ids
|
||||
self.eos_id = tokenizer(f'{tokenizer.eos_token}', add_special_tokens=False).input_ids
|
||||
|
||||
def __len__(self):
|
||||
return len(self.samples)
|
||||
|
||||
def load_data(self, path):
|
||||
samples = []
|
||||
with open(path, 'r', encoding='utf-8') as f:
|
||||
for line_num, line in enumerate(f, 1):
|
||||
data = json.loads(line.strip())
|
||||
samples.append(data)
|
||||
return samples
|
||||
|
||||
def _create_chat_prompt(self, cs):
|
||||
messages = cs.copy()
|
||||
tools = cs[0]["functions"] if (cs and cs[0]["role"] == "system" and cs[0].get("functions")) else None
|
||||
@ -127,12 +112,7 @@ class DPODataset(Dataset):
|
||||
self.padding = tokenizer.pad_token_id if tokenizer.pad_token_id is not None else 0
|
||||
self.bos_id = tokenizer(f'{tokenizer.bos_token}assistant', add_special_tokens=False).input_ids
|
||||
self.eos_id = tokenizer(f'{tokenizer.eos_token}', add_special_tokens=False).input_ids
|
||||
with open(file_path, 'r', encoding='utf-8') as f:
|
||||
self.data = []
|
||||
for line in f:
|
||||
line = line.strip()
|
||||
obj = json.loads(line)
|
||||
self.data.append(obj)
|
||||
self.data = load_dataset('json', data_files=file_path, split='train')
|
||||
|
||||
def __len__(self):
|
||||
return len(self.data)
|
||||
@ -200,21 +180,13 @@ class RLAIFDataset(Dataset):
|
||||
super().__init__()
|
||||
self.tokenizer = tokenizer
|
||||
self.max_length = max_length
|
||||
self.samples = self.load_data(jsonl_path)
|
||||
self.samples = load_dataset('json', data_files=jsonl_path, split='train')
|
||||
self.bos_id = tokenizer(f'{tokenizer.bos_token}assistant', add_special_tokens=False).input_ids
|
||||
self.eos_id = tokenizer(f'{tokenizer.eos_token}', add_special_tokens=False).input_ids
|
||||
|
||||
def __len__(self):
|
||||
return len(self.samples)
|
||||
|
||||
def load_data(self, path):
|
||||
samples = []
|
||||
with open(path, 'r', encoding='utf-8') as f:
|
||||
for line_num, line in enumerate(f, 1):
|
||||
data = json.loads(line.strip())
|
||||
samples.append(data)
|
||||
return samples
|
||||
|
||||
def _create_chat_prompt(self, conversations):
|
||||
"""构建符合ChatML格式的对话"""
|
||||
messages = []
|
||||
|
||||
Loading…
Reference in New Issue
Block a user