diff --git a/dataset/lm_dataset.py b/dataset/lm_dataset.py index 89f68f3..3641b00 100644 --- a/dataset/lm_dataset.py +++ b/dataset/lm_dataset.py @@ -4,6 +4,7 @@ import numpy as np from torch.utils.data import Dataset, DataLoader import torch import os +from datasets import load_dataset os.environ["TOKENIZERS_PARALLELISM"] = "false" @@ -13,15 +14,7 @@ class PretrainDataset(Dataset): super().__init__() self.tokenizer = tokenizer self.max_length = max_length - self.samples = self.load_data(data_path) - - def load_data(self, path): - samples = [] - with open(path, 'r', encoding='utf-8') as f: - for line_num, line in enumerate(f, 1): - data = json.loads(line.strip()) - samples.append(data) - return samples + self.samples = load_dataset('json', data_files=data_path, split='train') def __len__(self): return len(self.samples) @@ -51,21 +44,13 @@ class SFTDataset(Dataset): super().__init__() self.tokenizer = tokenizer self.max_length = max_length - self.samples = self.load_data(jsonl_path) + self.samples = load_dataset('json', data_files=jsonl_path, split='train') self.bos_id = tokenizer(f'{tokenizer.bos_token}assistant', add_special_tokens=False).input_ids self.eos_id = tokenizer(f'{tokenizer.eos_token}', add_special_tokens=False).input_ids def __len__(self): return len(self.samples) - def load_data(self, path): - samples = [] - with open(path, 'r', encoding='utf-8') as f: - for line_num, line in enumerate(f, 1): - data = json.loads(line.strip()) - samples.append(data) - return samples - def _create_chat_prompt(self, cs): messages = cs.copy() tools = cs[0]["functions"] if (cs and cs[0]["role"] == "system" and cs[0].get("functions")) else None @@ -127,12 +112,7 @@ class DPODataset(Dataset): self.padding = tokenizer.pad_token_id if tokenizer.pad_token_id is not None else 0 self.bos_id = tokenizer(f'{tokenizer.bos_token}assistant', add_special_tokens=False).input_ids self.eos_id = tokenizer(f'{tokenizer.eos_token}', add_special_tokens=False).input_ids - with open(file_path, 
- 'r', encoding='utf-8') as f: - self.data = [] - for line in f: - line = line.strip() - obj = json.loads(line) - self.data.append(obj) + self.data = load_dataset('json', data_files=file_path, split='train') def __len__(self): return len(self.data) @@ -200,21 +180,13 @@ class RLAIFDataset(Dataset): super().__init__() self.tokenizer = tokenizer self.max_length = max_length - self.samples = self.load_data(jsonl_path) + self.samples = load_dataset('json', data_files=jsonl_path, split='train') self.bos_id = tokenizer(f'{tokenizer.bos_token}assistant', add_special_tokens=False).input_ids self.eos_id = tokenizer(f'{tokenizer.eos_token}', add_special_tokens=False).input_ids def __len__(self): return len(self.samples) - def load_data(self, path): - samples = [] - with open(path, 'r', encoding='utf-8') as f: - for line_num, line in enumerate(f, 1): - data = json.loads(line.strip()) - samples.append(data) - return samples - def _create_chat_prompt(self, conversations): """构建符合ChatML格式的对话""" messages = []