mirror of
https://github.com/jingyaogong/minimind.git
synced 2026-04-25 08:48:16 +08:00
[update] align mask
This commit is contained in:
parent
c090b69c4d
commit
aa539a824a
@ -35,8 +35,8 @@ class SFTDataset(Dataset):
|
||||
self.tokenizer = tokenizer
|
||||
self.max_length = max_length
|
||||
self.samples = load_dataset('json', data_files=jsonl_path, split='train')
|
||||
self.bos_id = tokenizer(f'{tokenizer.bos_token}assistant', add_special_tokens=False).input_ids
|
||||
self.eos_id = tokenizer(f'{tokenizer.eos_token}', add_special_tokens=False).input_ids
|
||||
self.bos_id = tokenizer(f'{tokenizer.bos_token}assistant\n', add_special_tokens=False).input_ids
|
||||
self.eos_id = tokenizer(f'{tokenizer.eos_token}\n', add_special_tokens=False).input_ids
|
||||
|
||||
def __len__(self):
|
||||
return len(self.samples)
|
||||
@ -62,7 +62,7 @@ class SFTDataset(Dataset):
|
||||
if input_ids[end:end + len(self.eos_id)] == self.eos_id:
|
||||
break
|
||||
end += 1
|
||||
for j in range(start + 1, min(end + len(self.eos_id) + 1, self.max_length)):
|
||||
for j in range(start, min(end + len(self.eos_id), self.max_length)):
|
||||
labels[j] = input_ids[j]
|
||||
i = end + len(self.eos_id) if end < len(input_ids) else len(input_ids)
|
||||
else:
|
||||
@ -89,8 +89,8 @@ class DPODataset(Dataset):
|
||||
self.tokenizer = tokenizer
|
||||
self.max_length = max_length
|
||||
self.padding = tokenizer.pad_token_id if tokenizer.pad_token_id is not None else 0
|
||||
self.bos_id = tokenizer(f'{tokenizer.bos_token}assistant', add_special_tokens=False).input_ids
|
||||
self.eos_id = tokenizer(f'{tokenizer.eos_token}', add_special_tokens=False).input_ids
|
||||
self.bos_id = tokenizer(f'{tokenizer.bos_token}assistant\n', add_special_tokens=False).input_ids
|
||||
self.eos_id = tokenizer(f'{tokenizer.eos_token}\n', add_special_tokens=False).input_ids
|
||||
self.data = load_dataset('json', data_files=file_path, split='train')
|
||||
|
||||
def __len__(self):
|
||||
@ -146,7 +146,7 @@ class DPODataset(Dataset):
|
||||
if input_ids[end:end + len(self.eos_id)] == self.eos_id:
|
||||
break
|
||||
end += 1
|
||||
for j in range(start + 1, min(end + len(self.eos_id) + 1, self.max_length)):
|
||||
for j in range(start, min(end + len(self.eos_id), self.max_length)):
|
||||
loss_mask[j] = 1
|
||||
i = end + len(self.eos_id) if end < len(input_ids) else len(input_ids)
|
||||
else:
|
||||
|
||||
Loading…
Reference in New Issue
Block a user