# minimind/dataset/lm_dataset.py

import os
import json

import numpy as np
import pandas as pd
import torch
from datasets import load_dataset
from torch.utils.data import Dataset, DataLoader

os.environ["TOKENIZERS_PARALLELISM"] = "false"


class PretrainDataset(Dataset):
    def __init__(self, data_path, tokenizer, max_length=512):
        super().__init__()
        self.tokenizer = tokenizer
        self.max_length = max_length
        self.samples = load_dataset('json', data_files=data_path, split='train')

    def __len__(self):
        return len(self.samples)

    def __getitem__(self, index):
        sample = self.samples[index]
        # Tokenize the raw text, padding/truncating to a fixed length
        encoding = self.tokenizer(
            str(sample['text']),
            max_length=self.max_length,
            padding='max_length',
            truncation=True,
            return_tensors='pt'
        )
        input_ids = encoding.input_ids.squeeze()
        loss_mask = (input_ids != self.tokenizer.pad_token_id)
        # Next-token prediction: X predicts Y, and padding positions are excluded from the loss
        X = input_ids[:-1].clone().detach()
        Y = input_ids[1:].clone().detach()
        loss_mask = loss_mask[1:].clone().detach().long()
        return X, Y, loss_mask
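
# Usage sketch (illustrative, not part of the original file): how PretrainDataset is
# typically wrapped in a DataLoader. The tokenizer directory and JSONL path below are
# assumptions, not values defined in this module.
#
#   from transformers import AutoTokenizer
#   tokenizer = AutoTokenizer.from_pretrained('./model/minimind_tokenizer')
#   train_ds = PretrainDataset('./dataset/pretrain_hq.jsonl', tokenizer, max_length=512)
#   train_loader = DataLoader(train_ds, batch_size=32, shuffle=True, num_workers=2)
#   X, Y, loss_mask = next(iter(train_loader))   # each of shape (32, 511)
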

class SFTDataset(Dataset):
    def __init__(self, jsonl_path, tokenizer, max_length=1024):
        super().__init__()
        self.tokenizer = tokenizer
        self.max_length = max_length
        self.samples = load_dataset('json', data_files=jsonl_path, split='train')
        # Token-id sequences marking the start and end of an assistant reply
        self.bos_id = tokenizer(f'{tokenizer.bos_token}assistant', add_special_tokens=False).input_ids
        self.eos_id = tokenizer(f'{tokenizer.eos_token}', add_special_tokens=False).input_ids

    def __len__(self):
        return len(self.samples)

    def _create_chat_prompt(self, cs):
        messages = cs.copy()
        # If the first message is a system turn carrying function definitions, pass them as tools
        tools = cs[0]["functions"] if (cs and cs[0]["role"] == "system" and cs[0].get("functions")) else None
        return self.tokenizer.apply_chat_template(
            messages,
            tokenize=False,
            add_generation_prompt=False,
            tools=tools
        )

    def _generate_loss_mask(self, input_ids):
        # Only tokens inside assistant replies (between bos_id and eos_id) contribute to the loss
        loss_mask = [0] * len(input_ids)
        i = 0
        while i < len(input_ids):
            if input_ids[i:i + len(self.bos_id)] == self.bos_id:
                start = i + len(self.bos_id)
                end = start
                # Scan forward until the end-of-turn sequence (or the end of input)
                while end < len(input_ids):
                    if input_ids[end:end + len(self.eos_id)] == self.eos_id:
                        break
                    end += 1
                for j in range(start + 1, min(end + len(self.eos_id) + 1, self.max_length)):
                    loss_mask[j] = 1
                i = end + len(self.eos_id) if end < len(input_ids) else len(input_ids)
            else:
                i += 1
        return loss_mask
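
    # Illustrative trace (hypothetical token ids, not from the real tokenizer): suppose
    # bos_id == [1, 2], eos_id == [3], and input_ids == [9, 9, 1, 2, 5, 6, 7, 3, 0, 0],
    # where 5, 6, 7 form the assistant reply and 0 is padding. The scan matches bos at
    # index 2, so start == 4 and the eos match gives end == 7; positions 5..8 are set to 1,
    # i.e. loss_mask == [0, 0, 0, 0, 0, 1, 1, 1, 1, 0]. After the one-position shift in
    # __getitem__, the supervised targets are tokens 6, 7, the eos token 3 and the token
    # right after it, so the loss concentrates on the assistant reply span.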

    def __getitem__(self, index):
        sample = self.samples[index]
        # Build the chat prompt
        prompt = self._create_chat_prompt(sample['conversations'])
        input_ids = self.tokenizer(prompt).input_ids[:self.max_length]
        input_ids += [self.tokenizer.pad_token_id] * (self.max_length - len(input_ids))
        # Generate the dynamic loss mask
        loss_mask = self._generate_loss_mask(input_ids)
        # Build the training tensors
        X = torch.tensor(input_ids[:-1], dtype=torch.long)
        Y = torch.tensor(input_ids[1:], dtype=torch.long)
        loss_mask = torch.tensor(loss_mask[1:], dtype=torch.long)  # align with the predicted positions
        # # === Print the loss mask for every token (debugging aid) ===
        # print(f"\n--- Sample {index} Token Loss Mask (length: {len(input_ids)}) ---")
        # for i, (token_id, mask) in enumerate(zip(input_ids, loss_mask)):
        #     token_str = self.tokenizer.decode([token_id], skip_special_tokens=False)
        #     token_str = token_str.replace('\n', '\\n').replace('\t', '\\t')  # escape invisible characters
        #     print(f"Token {i:3d}: {token_id:5d} -> '{token_str:10s}' | mask: {mask}")
        # print(f"--- End of Sample {index} ---")
        # # ===========================================================
        return X, Y, loss_mask
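
# Usage sketch (illustrative): SFTDataset expects JSONL records whose 'conversations' field
# is a list of {role, content} messages. This reuses the assumed tokenizer from the sketch
# above; the data path and batch size are likewise assumptions.
#
#   sft_ds = SFTDataset('./dataset/sft_mini_512.jsonl', tokenizer, max_length=1024)
#   loader = DataLoader(sft_ds, batch_size=16, shuffle=True)
#   X, Y, loss_mask = next(iter(loader))   # each of shape (16, 1023)
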

class DPODataset(Dataset):
    def __init__(self, file_path, tokenizer, max_length=4096):
        super().__init__()
        self.tokenizer = tokenizer
        self.max_length = max_length
        self.padding = tokenizer.pad_token_id if tokenizer.pad_token_id is not None else 0
        self.bos_id = tokenizer(f'{tokenizer.bos_token}assistant', add_special_tokens=False).input_ids
        self.eos_id = tokenizer(f'{tokenizer.eos_token}', add_special_tokens=False).input_ids
        self.data = load_dataset('json', data_files=file_path, split='train')

    def __len__(self):
        return len(self.data)

    def __getitem__(self, index):
        item = self.data[index]
        chosen = item['chosen']      # a list of {role, content} messages
        rejected = item['rejected']  # same structure as chosen
        chosen_prompt = self.tokenizer.apply_chat_template(
            chosen, tokenize=False, add_generation_prompt=False
        )
        rejected_prompt = self.tokenizer.apply_chat_template(
            rejected, tokenize=False, add_generation_prompt=False
        )
        chosen_encoding = self.tokenizer(
            chosen_prompt, truncation=True, max_length=self.max_length, padding='max_length'
        )
        rejected_encoding = self.tokenizer(
            rejected_prompt, truncation=True, max_length=self.max_length, padding='max_length'
        )
        chosen_input_ids = chosen_encoding['input_ids']
        chosen_loss_mask = self._generate_loss_mask(chosen_input_ids)
        rejected_input_ids = rejected_encoding['input_ids']
        rejected_loss_mask = self._generate_loss_mask(rejected_input_ids)
        x_chosen = torch.tensor(chosen_input_ids[:-1], dtype=torch.long)
        y_chosen = torch.tensor(chosen_input_ids[1:], dtype=torch.long)
        mask_chosen = torch.tensor(chosen_loss_mask[1:], dtype=torch.long)
        x_rejected = torch.tensor(rejected_input_ids[:-1], dtype=torch.long)
        y_rejected = torch.tensor(rejected_input_ids[1:], dtype=torch.long)
        mask_rejected = torch.tensor(rejected_loss_mask[1:], dtype=torch.long)
        return {
            'x_chosen': x_chosen,
            'y_chosen': y_chosen,
            'mask_chosen': mask_chosen,
            'x_rejected': x_rejected,
            'y_rejected': y_rejected,
            'mask_rejected': mask_rejected
        }

    def _generate_loss_mask(self, input_ids):
        # Same masking rule as SFTDataset: only assistant-reply spans count toward the loss
        loss_mask = [0] * len(input_ids)
        i = 0
        while i < len(input_ids):
            if input_ids[i:i + len(self.bos_id)] == self.bos_id:
                start = i + len(self.bos_id)
                end = start
                while end < len(input_ids):
                    if input_ids[end:end + len(self.eos_id)] == self.eos_id:
                        break
                    end += 1
                for j in range(start + 1, min(end + len(self.eos_id) + 1, self.max_length)):
                    loss_mask[j] = 1
                i = end + len(self.eos_id) if end < len(input_ids) else len(input_ids)
            else:
                i += 1
        return loss_mask
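
# Usage sketch (illustrative): each DPO record is expected to hold parallel 'chosen' and
# 'rejected' conversations, and the returned dict batches cleanly with the default collate
# function. This reuses the assumed tokenizer from the first sketch; the data path is an
# assumption.
#
#   dpo_ds = DPODataset('./dataset/dpo.jsonl', tokenizer, max_length=1024)
#   batch = next(iter(DataLoader(dpo_ds, batch_size=4, shuffle=True)))
#   batch['x_chosen'].shape, batch['mask_rejected'].shape   # both (4, 1023)
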

class RLAIFDataset(Dataset):
    def __init__(self, jsonl_path, tokenizer, max_length=1024):
        super().__init__()
        self.tokenizer = tokenizer
        self.max_length = max_length
        self.samples = load_dataset('json', data_files=jsonl_path, split='train')
        self.bos_id = tokenizer(f'{tokenizer.bos_token}assistant', add_special_tokens=False).input_ids
        self.eos_id = tokenizer(f'{tokenizer.eos_token}', add_special_tokens=False).input_ids

    def __len__(self):
        return len(self.samples)

    def _create_chat_prompt(self, conversations):
        """Build a ChatML-style prompt and return it together with the reference answer."""
        messages = []
        answer = ''
        for i, turn in enumerate(conversations):
            role = 'user' if i % 2 == 0 else 'assistant'
            messages.append({"role": role, "content": turn['content']})
            answer = turn['content']
        # The final turn is dropped from the prompt and kept separately as the answer
        return self.tokenizer.apply_chat_template(
            messages[:-1],
            tokenize=False,
            add_generation_prompt=True  # must be True so the prompt ends with the assistant header
        ), answer

    def __getitem__(self, index):
        sample = self.samples[index]
        # Build the chat prompt and the held-out reference answer
        prompt, answer = self._create_chat_prompt(sample['conversations'])
        return {
            'prompt': prompt,
            'answer': answer
        }
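
# Usage sketch (illustrative): RLAIFDataset returns raw strings rather than tensors, so a
# downstream RL-style trainer tokenizes the prompt itself and keeps 'answer' only as a
# reference. This reuses the assumed tokenizer from the first sketch; the data path is an
# assumption.
#
#   rl_ds = RLAIFDataset('./dataset/rlaif.jsonl', tokenizer, max_length=1024)
#   item = rl_ds[0]
#   item['prompt']   # chat-templated prompt ending with the assistant generation header
#   item['answer']   # content of the final conversation turn
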

if __name__ == "__main__":
    pass
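
    # Minimal smoke test (a sketch, not part of the original file): builds two of the
    # datasets and prints the shape of one sample. The tokenizer directory and data
    # paths are assumptions and must exist locally for this to run.
    from transformers import AutoTokenizer

    tokenizer = AutoTokenizer.from_pretrained('./model/minimind_tokenizer')

    pretrain_ds = PretrainDataset('./dataset/pretrain_hq.jsonl', tokenizer, max_length=512)
    X, Y, loss_mask = pretrain_ds[0]
    print('pretrain sample:', X.shape, Y.shape, loss_mask.shape)

    sft_ds = SFTDataset('./dataset/sft_mini_512.jsonl', tokenizer, max_length=1024)
    X, Y, loss_mask = sft_ds[0]
    print('sft sample:', X.shape, Y.shape, 'supervised tokens:', int(loss_mask.sum()))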