mirror of
https://github.com/NVIDIA/TensorRT-LLM.git
synced 2026-01-14 06:27:45 +08:00
* Update TensorRT-LLM --------- Co-authored-by: Shixiaowei02 <39303645+Shixiaowei02@users.noreply.github.com>
8 lines
216 B
Python
8 lines
216 B
Python
def ppl(logits, output_ids):
    """
    Calculate the perplexity of ``output_ids`` under ``logits``.

    The result is a single scalar: the exponential of the mean negative
    log-likelihood taken over every position in ``output_ids``. It is not
    a per-position vector.

    Args:
        logits: Floating-point tensor of unnormalized scores; the last
            dimension is normalized with ``log_softmax``. Assumed to be
            shaped ``(..., vocab_size)`` -- confirm against callers.
        output_ids: Integer tensor of indices into the last dimension of
            ``logits``. It is cast to int64 and given a trailing axis
            before the gather, so it is assumed to have the shape of
            ``logits`` without its last dimension -- confirm against
            callers.

    Returns:
        The perplexity as a Python ``float``.
    """
    # Half-precision inputs (fp16 / bf16) are upcast to float32 first:
    # otherwise the log-softmax output, the mean and the final exp() are
    # all rounded to the narrow dtype, and any perplexity above the fp16
    # maximum (~65504) overflows to inf. float32 / float64 inputs are
    # left untouched, so their results are unchanged.
    if logits.is_floating_point() and logits.element_size() < 4:
        logits = logits.float()

    log_probs = logits.log_softmax(dim=-1)
    # Select the log-probability assigned to each reference token.
    token_log_probs = log_probs.gather(-1, output_ids.long().unsqueeze(-1))
    # Perplexity = exp(mean negative log-likelihood).
    return (-token_log_probs).mean().exp().item()
|