This commit is contained in:
LearnMan 2026-01-08 19:08:28 +00:00 committed by GitHub
commit db82ca8bce
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

View File

@ -79,9 +79,12 @@ def main():
st = time.time()
generated_ids = model.generate(
inputs=inputs["input_ids"], attention_mask=inputs["attention_mask"],
max_new_tokens=args.max_new_tokens, do_sample=True, streamer=streamer,
max_new_tokens=min(args.max_new_tokens, 200),
do_sample=True, streamer=streamer,
pad_token_id=tokenizer.pad_token_id, eos_token_id=tokenizer.eos_token_id,
top_p=args.top_p, temperature=args.temperature, repetition_penalty=1.0
top_p=args.top_p, temperature=args.temperature,
repetition_penalty=1.15, # ← 只保留这一行
early_stopping=True,
)
response = tokenizer.decode(generated_ids[0][len(inputs["input_ids"][0]):], skip_special_tokens=True)
conversation.append({"role": "assistant", "content": response})
@ -89,4 +92,4 @@ def main():
print(f'\n[Speed]: {gen_tokens / (time.time() - st):.2f} tokens/s\n\n') if args.show_speed else print('\n\n')
if __name__ == "__main__":
main()
main()