[update] align mask

This commit is contained in:
jingyaogong 2026-01-15 11:20:41 +08:00
parent c090b69c4d
commit aa539a824a

View File

@ -35,8 +35,8 @@ class SFTDataset(Dataset):
self.tokenizer = tokenizer
self.max_length = max_length
self.samples = load_dataset('json', data_files=jsonl_path, split='train')
self.bos_id = tokenizer(f'{tokenizer.bos_token}assistant', add_special_tokens=False).input_ids
self.eos_id = tokenizer(f'{tokenizer.eos_token}', add_special_tokens=False).input_ids
self.bos_id = tokenizer(f'{tokenizer.bos_token}assistant\n', add_special_tokens=False).input_ids
self.eos_id = tokenizer(f'{tokenizer.eos_token}\n', add_special_tokens=False).input_ids
def __len__(self):
return len(self.samples)
@ -62,7 +62,7 @@ class SFTDataset(Dataset):
if input_ids[end:end + len(self.eos_id)] == self.eos_id:
break
end += 1
for j in range(start + 1, min(end + len(self.eos_id) + 1, self.max_length)):
for j in range(start, min(end + len(self.eos_id), self.max_length)):
labels[j] = input_ids[j]
i = end + len(self.eos_id) if end < len(input_ids) else len(input_ids)
else:
@ -89,8 +89,8 @@ class DPODataset(Dataset):
self.tokenizer = tokenizer
self.max_length = max_length
self.padding = tokenizer.pad_token_id if tokenizer.pad_token_id is not None else 0
self.bos_id = tokenizer(f'{tokenizer.bos_token}assistant', add_special_tokens=False).input_ids
self.eos_id = tokenizer(f'{tokenizer.eos_token}', add_special_tokens=False).input_ids
self.bos_id = tokenizer(f'{tokenizer.bos_token}assistant\n', add_special_tokens=False).input_ids
self.eos_id = tokenizer(f'{tokenizer.eos_token}\n', add_special_tokens=False).input_ids
self.data = load_dataset('json', data_files=file_path, split='train')
def __len__(self):
@ -146,7 +146,7 @@ class DPODataset(Dataset):
if input_ids[end:end + len(self.eos_id)] == self.eos_id:
break
end += 1
for j in range(start + 1, min(end + len(self.eos_id) + 1, self.max_length)):
for j in range(start, min(end + len(self.eos_id), self.max_length)):
loss_mask[j] = 1
i = end + len(self.eos_id) if end < len(input_ids) else len(input_ids)
else: