mirror of
https://github.com/NVIDIA/TensorRT-LLM.git
synced 2026-01-14 06:27:45 +08:00
64 lines
2.5 KiB
Python
64 lines
2.5 KiB
Python
#!/usr/bin/python
|
|
|
|
import argparse
|
|
import json
|
|
|
|
from transformers import AutoTokenizer, LlamaTokenizer, T5Tokenizer
|
|
|
|
def _load_tokenizer(tokenizer_dir, tokenizer_type):
    """Load a left-padding tokenizer of the requested type.

    Args:
        tokenizer_dir: Path to the tokenizer directory (or vocab file for t5).
        tokenizer_type: One of 'auto', 't5', or 'llama'.

    Returns:
        A tokenizer instance configured with padding_side='left'.

    Raises:
        AttributeError: If tokenizer_type is not one of the supported values.
    """
    if tokenizer_type == 't5':
        # T5Tokenizer takes the vocab file path directly, not a directory.
        return T5Tokenizer(vocab_file=tokenizer_dir, padding_side='left')
    if tokenizer_type == 'auto':
        return AutoTokenizer.from_pretrained(tokenizer_dir,
                                             padding_side='left')
    if tokenizer_type == 'llama':
        # legacy=False selects the fixed sentencepiece handling.
        return LlamaTokenizer.from_pretrained(tokenizer_dir,
                                              legacy=False,
                                              padding_side='left')
    raise AttributeError(f'Unexpected tokenizer type: {tokenizer_type}')


def main():
    """Preprocess an instruction-style JSON dataset for benchmarking.

    Each dataset record is expected to have 'input', 'instruction', and
    'output' fields. The prompt ('input' + ' ' + 'instruction') is
    tokenized; records whose token count exceeds --max_input_len are
    dropped. Surviving records are written to --output as a JSON list of
    {'input_ids': [...], 'output_len': int}.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument('--dataset',
                        type=str,
                        required=True,
                        help='Dataset path used for the test.')
    parser.add_argument('--max_input_len',
                        type=int,
                        required=True,
                        help='Specify max input length')
    parser.add_argument('--tokenizer_dir',
                        type=str,
                        required=True,
                        help='Specify tokenizer directory')
    parser.add_argument('--tokenizer_type',
                        type=str,
                        default='auto',
                        required=False,
                        choices=['auto', 't5', 'llama'],
                        help='Specify tokenizer type')
    parser.add_argument('--output',
                        type=str,
                        default='preprocessed_dataset.json',
                        help='Preprocessed dataset path.')
    flags = parser.parse_args()

    tokenizer = _load_tokenizer(flags.tokenizer_dir, flags.tokenizer_type)
    # Left padding uses EOS as the pad token (common for decoder-only models).
    tokenizer.pad_token = tokenizer.eos_token

    results = []
    # Explicit UTF-8: JSON datasets are UTF-8; don't depend on the locale.
    with open(flags.dataset, 'r', encoding='utf-8') as f:
        data_dict = json.load(f)
    for req in data_dict:
        prompt = req['input'] + ' ' + req['instruction']
        output = req['output']
        line = tokenizer.encode(prompt)
        # Skip prompts that would overflow the engine's max input length.
        if len(line) > flags.max_input_len:
            continue
        # 1.3 is a magic number that converts number of words to number of tokens
        output_len = int(len(output.split(' ')) * 1.3)
        results.append({'input_ids': line, 'output_len': output_len})

    with open(flags.output, 'w', encoding='utf-8') as f:
        # ensure_ascii=False keeps non-ASCII prompt text readable in the output.
        json.dump(results, f, ensure_ascii=False)


if __name__ == '__main__':
    main()