From 3a5aba82dbecb8b31349340f0e7cf8df94b9aa38 Mon Sep 17 00:00:00 2001 From: jingyaogong Date: Sat, 17 Jan 2026 13:26:14 +0800 Subject: [PATCH] [fix] max length --- dataset/lm_dataset.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/dataset/lm_dataset.py b/dataset/lm_dataset.py index 1a57ce9..1dd1558 100644 --- a/dataset/lm_dataset.py +++ b/dataset/lm_dataset.py @@ -16,8 +16,8 @@ class PretrainDataset(Dataset): def __getitem__(self, index): sample = self.samples[index] - tokens = self.tokenizer(str(sample['text']), add_special_tokens=False).input_ids - tokens = [self.tokenizer.bos_token_id] + tokens[:self.max_length - 2] + [self.tokenizer.eos_token_id] + tokens = self.tokenizer(str(sample['text']), add_special_tokens=False, max_length=self.max_length - 2, truncation=True).input_ids + tokens = [self.tokenizer.bos_token_id] + tokens + [self.tokenizer.eos_token_id] input_ids = tokens + [self.tokenizer.pad_token_id] * (self.max_length - len(tokens)) input_ids = torch.tensor(input_ids, dtype=torch.long) labels = input_ids.clone()