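[fix] repetition_penalty: divide positive logits by the penalty and multiply non-positive logits by it, instead of dividing all seen-token logits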

This commit is contained in:
jingyaogong
2026-05-07 19:08:52 +08:00
parent 802c15b2b4
commit dddedc6881
+2 -1
View File
@@ -266,7 +266,8 @@ class MiniMindForCausalLM(PreTrainedModel, GenerationMixin):
attention_mask = torch.cat([attention_mask, attention_mask.new_ones(attention_mask.shape[0], 1)], -1) if attention_mask is not None else None
logits = outputs.logits[:, -1, :] / temperature
if repetition_penalty != 1.0:
for i in range(input_ids.shape[0]): logits[i, torch.unique(input_ids[i])] /= repetition_penalty
for i in range(input_ids.shape[0]):
seen = torch.unique(input_ids[i]); score = logits[i, seen]; logits[i, seen] = torch.where(score > 0, score / repetition_penalty, score * repetition_penalty)
if top_k > 0:
logits[logits < torch.topk(logits, top_k)[0][..., -1, None]] = -float('inf')
if top_p < 1.0: