[fix] max length truncation in PretrainDataset tokenization

This commit is contained in:
jingyaogong 2026-01-17 13:26:14 +08:00
parent 714abcf802
commit 3a5aba82db

View File

@ -16,8 +16,8 @@ class PretrainDataset(Dataset):
def __getitem__(self, index):
sample = self.samples[index]
tokens = self.tokenizer(str(sample['text']), add_special_tokens=False).input_ids
tokens = [self.tokenizer.bos_token_id] + tokens[:self.max_length - 2] + [self.tokenizer.eos_token_id]
tokens = self.tokenizer(str(sample['text']), add_special_tokens=False, max_length=self.max_length - 2, truncation=True).input_ids
tokens = [self.tokenizer.bos_token_id] + tokens + [self.tokenizer.eos_token_id]
input_ids = tokens + [self.tokenizer.pad_token_id] * (self.max_length - len(tokens))
input_ids = torch.tensor(input_ids, dtype=torch.long)
labels = input_ids.clone()