mirror of
https://github.com/jingyaogong/minimind.git
synced 2026-01-13 19:57:20 +08:00
[update] show speed
This commit is contained in:
parent
df89069362
commit
05d0b216f6
13
eval_llm.py
13
eval_llm.py
@ -1,7 +1,7 @@
|
||||
import time
|
||||
import argparse
|
||||
import random
|
||||
import warnings
|
||||
import numpy as np
|
||||
import torch
|
||||
from transformers import AutoTokenizer, AutoModelForCausalLM, TextStreamer
|
||||
from model.model_minimind import MiniMindConfig, MiniMindForCausalLM
|
||||
@ -43,6 +43,7 @@ def main():
|
||||
parser.add_argument('--temperature', default=0.85, type=float, help="生成温度,控制随机性(0-1,越大越随机)")
|
||||
parser.add_argument('--top_p', default=0.85, type=float, help="nucleus采样阈值(0-1)")
|
||||
parser.add_argument('--historys', default=0, type=int, help="携带历史对话轮数(需为偶数,0表示不携带历史)")
|
||||
parser.add_argument('--show_speed', default=1, type=int, help="显示decode速度(tokens/s)")
|
||||
parser.add_argument('--device', default='cuda' if torch.cuda.is_available() else 'cpu', type=str, help="运行设备")
|
||||
args = parser.parse_args()
|
||||
|
||||
@ -62,10 +63,10 @@ def main():
|
||||
input_mode = int(input('[0] 自动测试\n[1] 手动输入\n'))
|
||||
streamer = TextStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)
|
||||
|
||||
prompt_iter = prompts if input_mode == 0 else iter(lambda: input('👶: '), '')
|
||||
prompt_iter = prompts if input_mode == 0 else iter(lambda: input('💬: '), '')
|
||||
for prompt in prompt_iter:
|
||||
setup_seed(2026) # or setup_seed(random.randint(0, 2048))
|
||||
if input_mode == 0: print(f'👶: {prompt}')
|
||||
if input_mode == 0: print(f'💬: {prompt}')
|
||||
conversation = conversation[-args.historys:] if args.historys else []
|
||||
conversation.append({"role": "user", "content": prompt})
|
||||
|
||||
@ -74,7 +75,8 @@ def main():
|
||||
inputs = tokenizer.apply_chat_template(**templates) if args.weight != 'pretrain' else (tokenizer.bos_token + prompt)
|
||||
inputs = tokenizer(inputs, return_tensors="pt", truncation=True).to(args.device)
|
||||
|
||||
print('🤖️: ', end='')
|
||||
print('🤖: ', end='')
|
||||
st = time.time()
|
||||
generated_ids = model.generate(
|
||||
inputs=inputs["input_ids"], attention_mask=inputs["attention_mask"],
|
||||
max_new_tokens=args.max_new_tokens, do_sample=True, streamer=streamer,
|
||||
@ -83,7 +85,8 @@ def main():
|
||||
)
|
||||
response = tokenizer.decode(generated_ids[0][len(inputs["input_ids"][0]):], skip_special_tokens=True)
|
||||
conversation.append({"role": "assistant", "content": response})
|
||||
print('\n\n')
|
||||
gen_tokens = len(generated_ids[0]) - len(inputs["input_ids"][0])
|
||||
print(f'\n[Speed]: {gen_tokens / (time.time() - st):.2f} tokens/s\n\n') if args.show_speed else print('\n\n')
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
Loading…
Reference in New Issue
Block a user