[update] show speed

This commit is contained in:
jingyaogong 2026-01-07 23:33:47 +08:00
parent df89069362
commit 05d0b216f6

View File

@@ -1,7 +1,7 @@
import time
import argparse
import random
import warnings
import numpy as np
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM, TextStreamer
from model.model_minimind import MiniMindConfig, MiniMindForCausalLM
@@ -43,6 +43,7 @@ def main():
parser.add_argument('--temperature', default=0.85, type=float, help="生成温度控制随机性0-1越大越随机")
parser.add_argument('--top_p', default=0.85, type=float, help="nucleus采样阈值0-1")
parser.add_argument('--historys', default=0, type=int, help="携带历史对话轮数需为偶数0表示不携带历史")
parser.add_argument('--show_speed', default=1, type=int, help="显示decode速度tokens/s")
parser.add_argument('--device', default='cuda' if torch.cuda.is_available() else 'cpu', type=str, help="运行设备")
args = parser.parse_args()
@@ -62,10 +63,10 @@ def main():
input_mode = int(input('[0] 自动测试\n[1] 手动输入\n'))
streamer = TextStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)
prompt_iter = prompts if input_mode == 0 else iter(lambda: input('💬: '), '')
for prompt in prompt_iter:
setup_seed(2026)  # or setup_seed(random.randint(0, 2048))
if input_mode == 0: print(f'💬: {prompt}')
conversation = conversation[-args.historys:] if args.historys else []
conversation.append({"role": "user", "content": prompt})
@@ -74,7 +75,8 @@ def main():
inputs = tokenizer.apply_chat_template(**templates) if args.weight != 'pretrain' else (tokenizer.bos_token + prompt)
inputs = tokenizer(inputs, return_tensors="pt", truncation=True).to(args.device)
print('🤖: ', end='')
st = time.time()
generated_ids = model.generate(
inputs=inputs["input_ids"], attention_mask=inputs["attention_mask"],
max_new_tokens=args.max_new_tokens, do_sample=True, streamer=streamer,
@@ -83,7 +85,8 @@ def main():
)
response = tokenizer.decode(generated_ids[0][len(inputs["input_ids"][0]):], skip_special_tokens=True)
conversation.append({"role": "assistant", "content": response})
gen_tokens = len(generated_ids[0]) - len(inputs["input_ids"][0])
print(f'\n[Speed]: {gen_tokens / (time.time() - st):.2f} tokens/s\n\n') if args.show_speed else print('\n\n')
if __name__ == "__main__":
main()