diff --git a/model/model_minimind.py b/model/model_minimind.py index 35bb66f..dc080e5 100755 --- a/model/model_minimind.py +++ b/model/model_minimind.py @@ -266,7 +266,8 @@ class MiniMindForCausalLM(PreTrainedModel, GenerationMixin): attention_mask = torch.cat([attention_mask, attention_mask.new_ones(attention_mask.shape[0], 1)], -1) if attention_mask is not None else None logits = outputs.logits[:, -1, :] / temperature if repetition_penalty != 1.0: - for i in range(input_ids.shape[0]): logits[i, torch.unique(input_ids[i])] /= repetition_penalty + for i in range(input_ids.shape[0]): + seen = torch.unique(input_ids[i]); score = logits[i, seen]; logits[i, seen] = torch.where(score > 0, score / repetition_penalty, score * repetition_penalty) if top_k > 0: logits[logits < torch.topk(logits, top_k)[0][..., -1, None]] = -float('inf') if top_p < 1.0: