# minimind/dataset/lm_dataset.py

import os

import torch
from datasets import load_dataset
from torch.utils.data import Dataset

os.environ["TOKENIZERS_PARALLELISM"] = "false"


class PretrainDataset(Dataset):
    def __init__(self, data_path, tokenizer, max_length=512):
        super().__init__()
        self.tokenizer = tokenizer
        self.max_length = max_length
        self.samples = load_dataset('json', data_files=data_path, split='train')

    def __len__(self):
        return len(self.samples)

    def __getitem__(self, index):
        sample = self.samples[index]
        # Tokenize the input text, padded/truncated to max_length
        encoding = self.tokenizer(
            str(sample['text']),
            max_length=self.max_length,
            padding='max_length',
            truncation=True,
            return_tensors='pt'
        )
        input_ids = encoding.input_ids.squeeze()
        loss_mask = (input_ids != self.tokenizer.pad_token_id)
        # Slice the existing tensors instead of re-wrapping them in
        # torch.tensor(), which emits a copy-construct warning
        X = input_ids[:-1].long()
        Y = input_ids[1:].long()
        loss_mask = loss_mask[1:].long()
        return X, Y, loss_mask
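

# A minimal sketch of the shift performed in PretrainDataset.__getitem__,
# on a toy padded sequence (token ids are made up; 0 is the pad id):
#   input_ids = [5, 6, 7, 0, 0]
#   X         = [5, 6, 7, 0]    -> model inputs
#   Y         = [6, 7, 0, 0]    -> next-token targets, one step ahead
#   loss_mask = [1, 1, 0, 0]    -> pad targets carry no loss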


class SFTDataset(Dataset):
    def __init__(self, jsonl_path, tokenizer, max_length=1024):
        super().__init__()
        self.tokenizer = tokenizer
        self.max_length = max_length
        self.samples = load_dataset('json', data_files=jsonl_path, split='train')
        # Token-id patterns that open and close an assistant reply
        self.bos_id = tokenizer(f'{tokenizer.bos_token}assistant', add_special_tokens=False).input_ids
        self.eos_id = tokenizer(f'{tokenizer.eos_token}', add_special_tokens=False).input_ids

    def __len__(self):
        return len(self.samples)

    def create_chat_prompt(self, cs):
        messages = cs.copy()
        tools = cs[0]["functions"] if (cs and cs[0]["role"] == "system" and cs[0].get("functions")) else None
        return self.tokenizer.apply_chat_template(
            messages,
            tokenize=False,
            add_generation_prompt=False,
            tools=tools
        )

    def generate_loss_mask(self, input_ids):
        # Mark only tokens inside assistant replies as loss-bearing
        loss_mask = [0] * len(input_ids)
        i = 0
        while i < len(input_ids):
            if input_ids[i:i + len(self.bos_id)] == self.bos_id:
                start = i + len(self.bos_id)
                end = start
                while end < len(input_ids):
                    if input_ids[end:end + len(self.eos_id)] == self.eos_id:
                        break
                    end += 1
                for j in range(start + 1, min(end + len(self.eos_id) + 1, self.max_length)):
                    loss_mask[j] = 1
                i = end + len(self.eos_id) if end < len(input_ids) else len(input_ids)
            else:
                i += 1
        return loss_mask
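
    # A worked toy example of generate_loss_mask (all token ids are made up):
    # with self.bos_id = [1, 9] and self.eos_id = [2],
    #   input_ids = [1, 9, 5, 6, 2, 0, 0]
    # the opening pattern matches at i=0, so start = 2, and the scan finds the
    # eos at end = 4; range(start + 1, end + len(eos_id) + 1) then yields
    #   loss_mask = [0, 0, 0, 1, 1, 1, 0]
    # which __getitem__ shifts by one position so each mask entry lines up
    # with the target token it gates.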

    def __getitem__(self, index):
        sample = self.samples[index]
        prompt = self.create_chat_prompt(sample['conversations'])
        input_ids = self.tokenizer(prompt).input_ids[:self.max_length]
        input_ids += [self.tokenizer.pad_token_id] * (self.max_length - len(input_ids))
        loss_mask = self.generate_loss_mask(input_ids)
        # Build the training tensors
        X = torch.tensor(input_ids[:-1], dtype=torch.long)
        Y = torch.tensor(input_ids[1:], dtype=torch.long)
        loss_mask = torch.tensor(loss_mask[1:], dtype=torch.long)  # align with the prediction positions
        # # === Print the mask state of every token ===
        # print(f"\n--- Sample {index} ---")
        # for i, (x, y, m) in enumerate(zip(X, Y, loss_mask)):
        #     print(f"{i:3d}: X={self.tokenizer.decode([x])!r:16s} ---> Y={self.tokenizer.decode([y])!r:16s} mask={m}")
        # # ============================================
        return X, Y, loss_mask
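

# A minimal usage sketch for SFTDataset; the tokenizer and paths below are
# assumptions for illustration, not part of this file:
# from torch.utils.data import DataLoader
# from transformers import AutoTokenizer
# tokenizer = AutoTokenizer.from_pretrained('./model')      # hypothetical path
# ds = SFTDataset('./dataset/sft_data.jsonl', tokenizer)    # hypothetical path
# X, Y, loss_mask = next(iter(DataLoader(ds, batch_size=8)))
# # each tensor has shape (8, max_length - 1)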


class DPODataset(Dataset):
    def __init__(self, file_path, tokenizer, max_length=4096):
        super().__init__()
        self.tokenizer = tokenizer
        self.max_length = max_length
        self.padding = tokenizer.pad_token_id if tokenizer.pad_token_id is not None else 0
        self.bos_id = tokenizer(f'{tokenizer.bos_token}assistant', add_special_tokens=False).input_ids
        self.eos_id = tokenizer(f'{tokenizer.eos_token}', add_special_tokens=False).input_ids
        self.data = load_dataset('json', data_files=file_path, split='train')

    def __len__(self):
        return len(self.data)

    def __getitem__(self, index):
        item = self.data[index]
        chosen = item['chosen']      # a list of {role, content} dicts
        rejected = item['rejected']  # same structure as above
        chosen_prompt = self.tokenizer.apply_chat_template(
            chosen, tokenize=False, add_generation_prompt=False
        )
        rejected_prompt = self.tokenizer.apply_chat_template(
            rejected, tokenize=False, add_generation_prompt=False
        )
        chosen_encoding = self.tokenizer(
            chosen_prompt, truncation=True, max_length=self.max_length, padding='max_length'
        )
        rejected_encoding = self.tokenizer(
            rejected_prompt, truncation=True, max_length=self.max_length, padding='max_length'
        )
        chosen_input_ids = chosen_encoding['input_ids']
        chosen_loss_mask = self.generate_loss_mask(chosen_input_ids)
        rejected_input_ids = rejected_encoding['input_ids']
        rejected_loss_mask = self.generate_loss_mask(rejected_input_ids)
        x_chosen = torch.tensor(chosen_input_ids[:-1], dtype=torch.long)
        y_chosen = torch.tensor(chosen_input_ids[1:], dtype=torch.long)
        mask_chosen = torch.tensor(chosen_loss_mask[1:], dtype=torch.long)
        x_rejected = torch.tensor(rejected_input_ids[:-1], dtype=torch.long)
        y_rejected = torch.tensor(rejected_input_ids[1:], dtype=torch.long)
        mask_rejected = torch.tensor(rejected_loss_mask[1:], dtype=torch.long)
        return {
            'x_chosen': x_chosen,
            'y_chosen': y_chosen,
            'mask_chosen': mask_chosen,
            'x_rejected': x_rejected,
            'y_rejected': y_rejected,
            'mask_rejected': mask_rejected
        }

    def generate_loss_mask(self, input_ids):
        # Same span-masking logic as SFTDataset.generate_loss_mask
        loss_mask = [0] * len(input_ids)
        i = 0
        while i < len(input_ids):
            if input_ids[i:i + len(self.bos_id)] == self.bos_id:
                start = i + len(self.bos_id)
                end = start
                while end < len(input_ids):
                    if input_ids[end:end + len(self.eos_id)] == self.eos_id:
                        break
                    end += 1
                for j in range(start + 1, min(end + len(self.eos_id) + 1, self.max_length)):
                    loss_mask[j] = 1
                i = end + len(self.eos_id) if end < len(input_ids) else len(input_ids)
            else:
                i += 1
        return loss_mask
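

# A sketch of how a DPO training step might consume one collated batch from
# DPODataset (`model` and `batch` are assumptions, not defined in this file):
# logits = model(batch['x_chosen']).logits                   # (B, T, vocab)
# logp = torch.log_softmax(logits, dim=-1)
# token_logp = logp.gather(-1, batch['y_chosen'].unsqueeze(-1)).squeeze(-1)
# chosen_logp = (token_logp * batch['mask_chosen']).sum(-1)  # per-sequence log-prob
# # repeat for x_rejected/y_rejected/mask_rejected (and for a frozen reference
# # model) before forming the DPO logistic loss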


class RLAIFDataset(Dataset):
    def __init__(self, jsonl_path, tokenizer, max_length=1024):
        super().__init__()
        self.tokenizer = tokenizer
        self.max_length = max_length
        self.samples = load_dataset('json', data_files=jsonl_path, split='train')
        self.bos_id = tokenizer(f'{tokenizer.bos_token}assistant', add_special_tokens=False).input_ids
        self.eos_id = tokenizer(f'{tokenizer.eos_token}', add_special_tokens=False).input_ids

    def __len__(self):
        return len(self.samples)

    def create_chat_prompt(self, conversations):
        # Alternate user/assistant roles; keep the last reply as the reference answer
        messages = []
        answer = ''
        for i, turn in enumerate(conversations):
            role = 'user' if i % 2 == 0 else 'assistant'
            messages.append({"role": role, "content": turn['content']})
            answer = turn['content']
        return self.tokenizer.apply_chat_template(
            messages[:-1],
            tokenize=False,
            add_generation_prompt=True  # must be True here so the model is prompted to generate
        ), answer

    def __getitem__(self, index):
        sample = self.samples[index]
        prompt, answer = self.create_chat_prompt(sample['conversations'])
        return {
            'prompt': prompt,
            'answer': answer
        }
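

# Sketch: RLAIFDataset yields plain strings, ready for a rollout loop
# (the tokenizer and model below are assumptions):
# item = ds[0]
# ids = tokenizer(item['prompt'], return_tensors='pt').input_ids
# out = model.generate(ids, max_new_tokens=256)
# # item['answer'] keeps the dataset's reference reply for reward computation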


if __name__ == "__main__":
    pass
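    # A minimal smoke test one might run here, assuming a local tokenizer and
    # data file exist (both paths are hypothetical):
    # from transformers import AutoTokenizer
    # tokenizer = AutoTokenizer.from_pretrained('./model')
    # ds = PretrainDataset('./dataset/pretrain_hq.jsonl', tokenizer)
    # X, Y, mask = ds[0]
    # print(X.shape, Y.shape, int(mask.sum()))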