This commit is contained in:
ParthSareen 2025-11-11 18:06:40 -08:00
parent 8ffcb08465
commit 522804fb0c
2 changed files with 21 additions and 23 deletions

View File

@@ -4,29 +4,28 @@ import ollama
def print_logprobs(logprobs: Iterable[dict], label: str) -> None:
    """Pretty-print per-token log-probabilities returned by the Ollama API.

    Args:
        logprobs: Iterable of logprob entries; each entry is a dict with
            'token', 'logprob', and optionally 'top_logprobs' (a list of
            alternative {'token', 'logprob'} dicts).
        label: Heading printed above the listing.
    """
    print(f'\n{label}:')
    for entry in logprobs:
        token = entry.get('token', '')
        logprob = entry.get('logprob')
        # .get() yields None when the key is absent, and f'{None:.3f}'
        # raises TypeError — fall back to a placeholder instead of crashing.
        logprob_text = f'{logprob:.3f}' if logprob is not None else 'n/a'
        print(f'  token={token!r:<12} logprob={logprob_text}')
        for alt in entry.get('top_logprobs', []):
            # Skip the alternative that duplicates the chosen token.
            if alt['token'] != token:
                print(f'    alt -> {alt["token"]!r:<12} ({alt["logprob"]:.3f})')
# A single-turn conversation used to demonstrate logprobs on chat responses.
messages = [
    {
        'role': 'user',
        'content': 'hi! be concise.',
    },
]

# Ask for token log-probabilities (plus the top-3 alternatives per token)
# alongside the regular chat completion.
response = ollama.chat(
    model='gemma3',
    messages=messages,
    logprobs=True,
    top_logprobs=3,
)

print('Chat response:', response['message']['content'])
print_logprobs(response.get('logprobs', []), 'chat logprobs')

View File

@@ -4,22 +4,21 @@ import ollama
def print_logprobs(logprobs: Iterable[dict], label: str) -> None:
    """Pretty-print per-token log-probabilities returned by the Ollama API.

    Args:
        logprobs: Iterable of logprob entries; each entry is a dict with
            'token', 'logprob', and optionally 'top_logprobs' (a list of
            alternative {'token', 'logprob'} dicts).
        label: Heading printed above the listing.
    """
    print(f'\n{label}:')
    for entry in logprobs:
        token = entry.get('token', '')
        logprob = entry.get('logprob')
        # .get() yields None when the key is absent, and f'{None:.3f}'
        # raises TypeError — fall back to a placeholder instead of crashing.
        logprob_text = f'{logprob:.3f}' if logprob is not None else 'n/a'
        print(f'  token={token!r:<12} logprob={logprob_text}')
        for alt in entry.get('top_logprobs', []):
            # Skip the alternative that duplicates the chosen token.
            if alt['token'] != token:
                print(f'    alt -> {alt["token"]!r:<12} ({alt["logprob"]:.3f})')
# Run a one-shot completion, requesting token log-probabilities (with the
# top-3 alternatives per token) in the response payload.
response = ollama.generate(
    model='gemma3',
    prompt='hi! be concise.',
    logprobs=True,
    top_logprobs=3,
)

print('Generate response:', response['response'])
print_logprobs(response.get('logprobs', []), 'generate logprobs')