mirror of
https://github.com/jingyaogong/minimind.git
synced 2026-01-13 19:57:20 +08:00
[update] mask log
This commit is contained in:
parent
20a43d7db0
commit
f55d4c32a0
@ -86,12 +86,9 @@ class SFTDataset(Dataset):
|
|||||||
Y = torch.tensor(input_ids[1:], dtype=torch.long)
|
Y = torch.tensor(input_ids[1:], dtype=torch.long)
|
||||||
loss_mask = torch.tensor(loss_mask[1:], dtype=torch.long) # 对齐预测位置
|
loss_mask = torch.tensor(loss_mask[1:], dtype=torch.long) # 对齐预测位置
|
||||||
# # === 打印每个token的掩码情况 ===
|
# # === 打印每个token的掩码情况 ===
|
||||||
# print(f"\n--- Sample {index} Token Loss Mask (length: {len(input_ids)}) ---")
|
# print(f"\n--- Sample {index} ---")
|
||||||
# for i, (token_id, mask) in enumerate(zip(input_ids, loss_mask)):
|
# for i, (x, y, m) in enumerate(zip(X, Y, loss_mask)):
|
||||||
# token_str = self.tokenizer.decode([token_id], skip_special_tokens=False)
|
# print(f"{i:3d}: X={self.tokenizer.decode([x])!r:16s} ---> Y={self.tokenizer.decode([y])!r:16s} mask={m}")
|
||||||
# token_str = token_str.replace('\n', '\\n').replace('\t', '\\t') # 处理换行等不可见字符
|
|
||||||
# print(f"Token {i:3d}: {token_id:5d} -> '{token_str:10s}' | mask: {mask}")
|
|
||||||
# print(f"--- End of Sample {index} ---")
|
|
||||||
# # ================================
|
# # ================================
|
||||||
return X, Y, loss_mask
|
return X, Y, loss_mask
|
||||||
|
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user