From 7641985d14380f6c74ce0afedf6ab776f428bcb7 Mon Sep 17 00:00:00 2001 From: jingyaogong Date: Tue, 6 Jan 2026 01:20:52 +0800 Subject: [PATCH] [update] simplify loader --- dataset/lm_dataset.py | 33 +++++++++++---------------------- 1 file changed, 11 insertions(+), 22 deletions(-) diff --git a/dataset/lm_dataset.py b/dataset/lm_dataset.py index 3641b00..1fa0397 100644 --- a/dataset/lm_dataset.py +++ b/dataset/lm_dataset.py @@ -1,14 +1,9 @@ -import json -import pandas as pd -import numpy as np -from torch.utils.data import Dataset, DataLoader +from torch.utils.data import Dataset import torch import os from datasets import load_dataset - os.environ["TOKENIZERS_PARALLELISM"] = "false" - class PretrainDataset(Dataset): def __init__(self, data_path, tokenizer, max_length=512): super().__init__() @@ -51,7 +46,7 @@ class SFTDataset(Dataset): def __len__(self): return len(self.samples) - def _create_chat_prompt(self, cs): + def create_chat_prompt(self, cs): messages = cs.copy() tools = cs[0]["functions"] if (cs and cs[0]["role"] == "system" and cs[0].get("functions")) else None return self.tokenizer.apply_chat_template( @@ -61,7 +56,7 @@ class SFTDataset(Dataset): tools=tools ) - def _generate_loss_mask(self, input_ids): + def generate_loss_mask(self, input_ids): loss_mask = [0] * len(input_ids) i = 0 while i < len(input_ids): @@ -81,13 +76,10 @@ class SFTDataset(Dataset): def __getitem__(self, index): sample = self.samples[index] - # 构建对话提示 - prompt = self._create_chat_prompt(sample['conversations']) + prompt = self.create_chat_prompt(sample['conversations']) input_ids = self.tokenizer(prompt).input_ids[:self.max_length] input_ids += [self.tokenizer.pad_token_id] * (self.max_length - len(input_ids)) - - # 生成动态损失掩码 - loss_mask = self._generate_loss_mask(input_ids) + loss_mask = self.generate_loss_mask(input_ids) # 构建训练数据 X = torch.tensor(input_ids[:-1], dtype=torch.long) @@ -136,10 +128,10 @@ class DPODataset(Dataset): ) chosen_input_ids = chosen_encoding['input_ids'] - chosen_loss_mask = self._generate_loss_mask(chosen_input_ids) + chosen_loss_mask = self.generate_loss_mask(chosen_input_ids) rejected_input_ids = rejected_encoding['input_ids'] - rejected_loss_mask = self._generate_loss_mask(rejected_input_ids) + rejected_loss_mask = self.generate_loss_mask(rejected_input_ids) x_chosen = torch.tensor(chosen_input_ids[:-1], dtype=torch.long) y_chosen = torch.tensor(chosen_input_ids[1:], dtype=torch.long) mask_chosen = torch.tensor(chosen_loss_mask[1:], dtype=torch.long) @@ -156,7 +148,7 @@ class DPODataset(Dataset): 'mask_rejected': mask_rejected } - def _generate_loss_mask(self, input_ids): + def generate_loss_mask(self, input_ids): loss_mask = [0] * len(input_ids) i = 0 while i < len(input_ids): @@ -187,8 +179,7 @@ class RLAIFDataset(Dataset): def __len__(self): return len(self.samples) - def _create_chat_prompt(self, conversations): - """构建符合ChatML格式的对话""" + def create_chat_prompt(self, conversations): messages = [] answer = '' for i, turn in enumerate(conversations): @@ -203,14 +194,12 @@ class RLAIFDataset(Dataset): def __getitem__(self, index): sample = self.samples[index] - # 构建对话提示 - prompt, answer = self._create_chat_prompt(sample['conversations']) + prompt, answer = self.create_chat_prompt(sample['conversations']) return { 'prompt': prompt, 'answer': answer } - if __name__ == "__main__": - pass + pass \ No newline at end of file