mirror of
https://github.com/jingyaogong/minimind.git
synced 2026-04-25 08:48:16 +08:00
[fix] max length
This commit is contained in:
parent
714abcf802
commit
3a5aba82db
@ -16,8 +16,8 @@ class PretrainDataset(Dataset):
|
||||
|
||||
def __getitem__(self, index):
|
||||
sample = self.samples[index]
|
||||
tokens = self.tokenizer(str(sample['text']), add_special_tokens=False).input_ids
|
||||
tokens = [self.tokenizer.bos_token_id] + tokens[:self.max_length - 2] + [self.tokenizer.eos_token_id]
|
||||
tokens = self.tokenizer(str(sample['text']), add_special_tokens=False, max_length=self.max_length - 2, truncation=True).input_ids
|
||||
tokens = [self.tokenizer.bos_token_id] + tokens + [self.tokenizer.eos_token_id]
|
||||
input_ids = tokens + [self.tokenizer.pad_token_id] * (self.max_length - len(tokens))
|
||||
input_ids = torch.tensor(input_ids, dtype=torch.long)
|
||||
labels = input_ids.clone()
|
||||
|
||||
Loading…
Reference in New Issue
Block a user