[update] mask log

This commit is contained in:
jingyaogong 2026-01-07 22:12:26 +08:00
parent 20a43d7db0
commit f55d4c32a0

View File

@ -86,12 +86,9 @@ class SFTDataset(Dataset):
Y = torch.tensor(input_ids[1:], dtype=torch.long) Y = torch.tensor(input_ids[1:], dtype=torch.long)
loss_mask = torch.tensor(loss_mask[1:], dtype=torch.long) # 对齐预测位置 loss_mask = torch.tensor(loss_mask[1:], dtype=torch.long) # 对齐预测位置
# # === 打印每个token的掩码情况 === # # === 打印每个token的掩码情况 ===
# print(f"\n--- Sample {index} Token Loss Mask (length: {len(input_ids)}) ---") # print(f"\n--- Sample {index} ---")
# for i, (token_id, mask) in enumerate(zip(input_ids, loss_mask)): # for i, (x, y, m) in enumerate(zip(X, Y, loss_mask)):
# token_str = self.tokenizer.decode([token_id], skip_special_tokens=False) # print(f"{i:3d}: X={self.tokenizer.decode([x])!r:16s} ---> Y={self.tokenizer.decode([y])!r:16s} mask={m}")
# token_str = token_str.replace('\n', '\\n').replace('\t', '\\t') # 处理换行等不可见字符
# print(f"Token {i:3d}: {token_id:5d} -> '{token_str:10s}' | mask: {mask}")
# print(f"--- End of Sample {index} ---")
# # ================================ # # ================================
return X, Y, loss_mask return X, Y, loss_mask