Fix DPO loss_mask boundary (include first assistant token)

Fix off-by-one in generate_loss_mask so the first assistant token contributes to loss in DPO.
This commit is contained in:
xiao-baia 2026-01-07 21:00:46 +08:00 committed by GitHub
parent 20a43d7db0
commit c972c4e090
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

View File

@ -159,7 +159,7 @@ class DPODataset(Dataset):
if input_ids[end:end + len(self.eos_id)] == self.eos_id:
break
end += 1
for j in range(start + 1, min(end + len(self.eos_id) + 1, self.max_length)):
for j in range(start, min(end + len(self.eos_id) + 1, self.max_length)):
loss_mask[j] = 1
i = end + len(self.eos_id) if end < len(input_ids) else len(input_ids)
else:
@ -202,4 +202,4 @@ class RLAIFDataset(Dataset):
}
if __name__ == "__main__":
pass
pass