diff --git a/dataset/lm_dataset.py b/dataset/lm_dataset.py index 21e321c..51db696 100644 --- a/dataset/lm_dataset.py +++ b/dataset/lm_dataset.py @@ -35,8 +35,8 @@ class SFTDataset(Dataset): self.tokenizer = tokenizer self.max_length = max_length self.samples = load_dataset('json', data_files=jsonl_path, split='train') - self.bos_id = tokenizer(f'{tokenizer.bos_token}assistant', add_special_tokens=False).input_ids - self.eos_id = tokenizer(f'{tokenizer.eos_token}', add_special_tokens=False).input_ids + self.bos_id = tokenizer(f'{tokenizer.bos_token}assistant\n', add_special_tokens=False).input_ids + self.eos_id = tokenizer(f'{tokenizer.eos_token}\n', add_special_tokens=False).input_ids def __len__(self): return len(self.samples) @@ -62,7 +62,7 @@ class SFTDataset(Dataset): if input_ids[end:end + len(self.eos_id)] == self.eos_id: break end += 1 - for j in range(start + 1, min(end + len(self.eos_id) + 1, self.max_length)): + for j in range(start, min(end + len(self.eos_id), self.max_length)): labels[j] = input_ids[j] i = end + len(self.eos_id) if end < len(input_ids) else len(input_ids) else: @@ -89,8 +89,8 @@ class DPODataset(Dataset): self.tokenizer = tokenizer self.max_length = max_length self.padding = tokenizer.pad_token_id if tokenizer.pad_token_id is not None else 0 - self.bos_id = tokenizer(f'{tokenizer.bos_token}assistant', add_special_tokens=False).input_ids - self.eos_id = tokenizer(f'{tokenizer.eos_token}', add_special_tokens=False).input_ids + self.bos_id = tokenizer(f'{tokenizer.bos_token}assistant\n', add_special_tokens=False).input_ids + self.eos_id = tokenizer(f'{tokenizer.eos_token}\n', add_special_tokens=False).input_ids self.data = load_dataset('json', data_files=file_path, split='train') def __len__(self): @@ -146,7 +146,7 @@ class DPODataset(Dataset): if input_ids[end:end + len(self.eos_id)] == self.eos_id: break end += 1 - for j in range(start + 1, min(end + len(self.eos_id) + 1, self.max_length)): + for j in range(start, min(end + len(self.eos_id), self.max_length)): loss_mask[j] = 1 i = end + len(self.eos_id) if end < len(input_ids) else len(input_ids) else: