mirror of
https://github.com/jingyaogong/minimind.git
synced 2026-06-06 00:04:50 +00:00
[fix] repetition_penalty
This commit is contained in:
@@ -266,7 +266,8 @@ class MiniMindForCausalLM(PreTrainedModel, GenerationMixin):
|
||||
attention_mask = torch.cat([attention_mask, attention_mask.new_ones(attention_mask.shape[0], 1)], -1) if attention_mask is not None else None
|
||||
logits = outputs.logits[:, -1, :] / temperature
|
||||
if repetition_penalty != 1.0:
|
||||
for i in range(input_ids.shape[0]): logits[i, torch.unique(input_ids[i])] /= repetition_penalty
|
||||
for i in range(input_ids.shape[0]):
|
||||
seen = torch.unique(input_ids[i]); score = logits[i, seen]; logits[i, seen] = torch.where(score > 0, score / repetition_penalty, score * repetition_penalty)
|
||||
if top_k > 0:
|
||||
logits[logits < torch.topk(logits, top_k)[0][..., -1, None]] = -float('inf')
|
||||
if top_p < 1.0:
|
||||
|
||||
Reference in New Issue
Block a user