mirror of
https://github.com/jingyaogong/minimind.git
synced 2026-01-13 19:57:20 +08:00
Merge f5079ce090 into 05d0b216f6
This commit is contained in:
commit
db82ca8bce
@ -79,9 +79,12 @@ def main():
|
||||
st = time.time()
|
||||
generated_ids = model.generate(
|
||||
inputs=inputs["input_ids"], attention_mask=inputs["attention_mask"],
|
||||
max_new_tokens=args.max_new_tokens, do_sample=True, streamer=streamer,
|
||||
max_new_tokens=min(args.max_new_tokens, 200),
|
||||
do_sample=True, streamer=streamer,
|
||||
pad_token_id=tokenizer.pad_token_id, eos_token_id=tokenizer.eos_token_id,
|
||||
top_p=args.top_p, temperature=args.temperature, repetition_penalty=1.0
|
||||
top_p=args.top_p, temperature=args.temperature,
|
||||
repetition_penalty=1.15, # ← 只保留这一行
|
||||
early_stopping=True,
|
||||
)
|
||||
response = tokenizer.decode(generated_ids[0][len(inputs["input_ids"][0]):], skip_special_tokens=True)
|
||||
conversation.append({"role": "assistant", "content": response})
|
||||
@ -89,4 +92,4 @@ def main():
|
||||
print(f'\n[Speed]: {gen_tokens / (time.time() - st):.2f} tokens/s\n\n') if args.show_speed else print('\n\n')
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
main()
|
||||
|
||||
Loading…
Reference in New Issue
Block a user