Mirror of https://github.com/jingyaogong/minimind.git (synced 2026-01-13 19:57:20 +08:00)

[feat] stream load data

Commit: 4a5c9f5ece
Parent: 7eae14f3ce
@@ -4,6 +4,7 @@ import numpy as np
 from torch.utils.data import Dataset, DataLoader
 import torch
 import os
+from datasets import load_dataset
 
 os.environ["TOKENIZERS_PARALLELISM"] = "false"
 
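The only change in this hunk is the new `datasets` import; every dataset class below swaps its hand-rolled JSONL reader for this library. A minimal sketch of the call the commit introduces, assuming a JSONL file with one JSON object per line (the file name pretrain.jsonl is illustrative):

    from datasets import load_dataset

    # Arrow-backed dataset: records are memory-mapped from disk instead of
    # being parsed eagerly into a Python list.
    ds = load_dataset('json', data_files='pretrain.jsonl', split='train')

    print(len(ds))   # number of JSON lines in the file
    print(ds[0])     # dict holding the fields of the first record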
@@ -13,15 +14,7 @@ class PretrainDataset(Dataset):
         super().__init__()
         self.tokenizer = tokenizer
         self.max_length = max_length
-        self.samples = self.load_data(data_path)
+        self.samples = load_dataset('json', data_files=data_path, split='train')
 
-    def load_data(self, path):
-        samples = []
-        with open(path, 'r', encoding='utf-8') as f:
-            for line_num, line in enumerate(f, 1):
-                data = json.loads(line.strip())
-                samples.append(data)
-        return samples
-
     def __len__(self):
         return len(self.samples)
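For PretrainDataset the eager load_data method disappears; the object returned by load_dataset still supports len() and integer indexing that yields a dict per JSONL line, so __len__ and __getitem__ keep working. A hedged sketch of that equivalence (the path and the comparison are illustrative, not part of the diff):

    import json
    from datasets import load_dataset

    path = 'pretrain.jsonl'  # illustrative path

    # Old behaviour: every line parsed up front into a Python list.
    with open(path, 'r', encoding='utf-8') as f:
        eager = [json.loads(line) for line in f if line.strip()]

    # New behaviour: lazy, Arrow-backed dataset with the same record count
    # and the same field names per record.
    lazy = load_dataset('json', data_files=path, split='train')

    print(len(eager), len(lazy))
    print(sorted(eager[0].keys()), sorted(lazy[0].keys()))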
@@ -51,21 +44,13 @@ class SFTDataset(Dataset):
         super().__init__()
         self.tokenizer = tokenizer
         self.max_length = max_length
-        self.samples = self.load_data(jsonl_path)
+        self.samples = load_dataset('json', data_files=jsonl_path, split='train')
         self.bos_id = tokenizer(f'{tokenizer.bos_token}assistant', add_special_tokens=False).input_ids
         self.eos_id = tokenizer(f'{tokenizer.eos_token}', add_special_tokens=False).input_ids
 
     def __len__(self):
         return len(self.samples)
 
-    def load_data(self, path):
-        samples = []
-        with open(path, 'r', encoding='utf-8') as f:
-            for line_num, line in enumerate(f, 1):
-                data = json.loads(line.strip())
-                samples.append(data)
-        return samples
-
     def _create_chat_prompt(self, cs):
         messages = cs.copy()
         tools = cs[0]["functions"] if (cs and cs[0]["role"] == "system" and cs[0].get("functions")) else None
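SFTDataset keeps its bos_id / eos_id token-id lists (the ids of f'{bos_token}assistant' and of the eos token); in this kind of dataset they are typically used to locate assistant replies inside the encoded prompt so that only those spans contribute to the loss. That logic is outside this diff; the helper below is only a generic sketch of finding such a marker in a token sequence (the name find_marker is an assumption):

    def find_marker(input_ids, marker_ids):
        # Return the start index of every occurrence of marker_ids in input_ids.
        # Generic sketch; the real dataset code may locate assistant spans differently.
        starts = []
        n, m = len(input_ids), len(marker_ids)
        for i in range(n - m + 1):
            if input_ids[i:i + m] == marker_ids:
                starts.append(i)
        return starts

    # Usage sketch: positions = find_marker(encoded_prompt_ids, dataset.bos_id)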
@@ -127,12 +112,7 @@ class DPODataset(Dataset):
         self.padding = tokenizer.pad_token_id if tokenizer.pad_token_id is not None else 0
         self.bos_id = tokenizer(f'{tokenizer.bos_token}assistant', add_special_tokens=False).input_ids
         self.eos_id = tokenizer(f'{tokenizer.eos_token}', add_special_tokens=False).input_ids
-        with open(file_path, 'r', encoding='utf-8') as f:
-            self.data = []
-            for line in f:
-                line = line.strip()
-                obj = json.loads(line)
-                self.data.append(obj)
+        self.data = load_dataset('json', data_files=file_path, split='train')
 
     def __len__(self):
         return len(self.data)
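DPODataset gets the same treatment: the line-by-line reader collapses into one load_dataset call, and __len__ keeps working because the returned dataset has a length. A hedged sketch of reading one preference record (the dpo.jsonl path and the chosen / rejected field names are assumptions about the schema, not shown in this diff):

    from datasets import load_dataset

    data = load_dataset('json', data_files='dpo.jsonl', split='train')

    record = data[0]                    # plain dict, one JSONL record
    chosen = record.get('chosen')       # assumed field: preferred response
    rejected = record.get('rejected')   # assumed field: rejected response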
@@ -200,21 +180,13 @@ class RLAIFDataset(Dataset):
         super().__init__()
         self.tokenizer = tokenizer
         self.max_length = max_length
-        self.samples = self.load_data(jsonl_path)
+        self.samples = load_dataset('json', data_files=jsonl_path, split='train')
         self.bos_id = tokenizer(f'{tokenizer.bos_token}assistant', add_special_tokens=False).input_ids
         self.eos_id = tokenizer(f'{tokenizer.eos_token}', add_special_tokens=False).input_ids
 
     def __len__(self):
         return len(self.samples)
 
-    def load_data(self, path):
-        samples = []
-        with open(path, 'r', encoding='utf-8') as f:
-            for line_num, line in enumerate(f, 1):
-                data = json.loads(line.strip())
-                samples.append(data)
-        return samples
-
     def _create_chat_prompt(self, conversations):
         """Build a conversation in ChatML format"""
         messages = []
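RLAIFDataset keeps its _create_chat_prompt helper, whose docstring says it builds a ChatML-style conversation. A minimal sketch of that idea using the tokenizer's chat template (the tokenizer path is a placeholder, not taken from this diff):

    from transformers import AutoTokenizer

    tokenizer = AutoTokenizer.from_pretrained('./model/minimind_tokenizer')  # placeholder path

    messages = [
        {"role": "system", "content": "You are a helpful assistant."},
        {"role": "user", "content": "Hello"},
    ]
    # Render the conversation with the chat template and leave the assistant
    # turn open so the model generates the reply.
    prompt = tokenizer.apply_chat_template(
        messages, tokenize=False, add_generation_prompt=True
    )
    print(prompt)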