Mirror of https://github.com/ollama/ollama-python.git
Synced 2026-01-13 13:47:17 +08:00
client/types: add logprobs support (#601)
This commit is contained in:
  parent 9ddd5f0182
  commit 0008226fda
examples/chat-logprobs.py (new file, 31 lines)
@@ -0,0 +1,31 @@
from typing import Iterable

import ollama


def print_logprobs(logprobs: Iterable[dict], label: str) -> None:
  print(f'\n{label}:')
  for entry in logprobs:
    token = entry.get('token', '')
    logprob = entry.get('logprob')
    print(f' token={token!r:<12} logprob={logprob:.3f}')
    for alt in entry.get('top_logprobs', []):
      if alt['token'] != token:
        print(f' alt -> {alt["token"]!r:<12} ({alt["logprob"]:.3f})')


messages = [
  {
    'role': 'user',
    'content': 'hi! be concise.',
  },
]

response = ollama.chat(
  model='gemma3',
  messages=messages,
  logprobs=True,
  top_logprobs=3,
)
print('Chat response:', response['message']['content'])
print_logprobs(response.get('logprobs', []), 'chat logprobs')
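Note: print_logprobs formats logprob with :.3f, which assumes every entry carries a numeric value. A slightly more defensive variant, hypothetical and not part of this commit, could skip incomplete entries:

def print_logprobs_safe(logprobs, label: str) -> None:
  # Hypothetical variant: skip entries that lack a numeric log probability.
  print(f'\n{label}:')
  for entry in logprobs:
    logprob = entry.get('logprob')
    if logprob is None:
      continue
    print(f' token={entry.get("token", "")!r:<12} logprob={logprob:.3f}')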
@@ -15,7 +15,8 @@ messages = [
   },
   {
     'role': 'assistant',
-    'content': 'The weather in Tokyo is typically warm and humid during the summer months, with temperatures often exceeding 30°C (86°F). The city experiences a rainy season from June to September, with heavy rainfall and occasional typhoons. Winter is mild, with temperatures rarely dropping below freezing. The city is known for its high-tech and vibrant culture, with many popular tourist attractions such as the Tokyo Tower, Senso-ji Temple, and the bustling Shibuya district.',
+    'content': """The weather in Tokyo is typically warm and humid during the summer months, with temperatures often exceeding 30°C (86°F). The city experiences a rainy season from June to September, with heavy rainfall and occasional typhoons. Winter is mild, with temperatures
+rarely dropping below freezing. The city is known for its high-tech and vibrant culture, with many popular tourist attractions such as the Tokyo Tower, Senso-ji Temple, and the bustling Shibuya district.""",
   },
 ]
 
examples/generate-logprobs.py (new file, 24 lines)
@@ -0,0 +1,24 @@
from typing import Iterable

import ollama


def print_logprobs(logprobs: Iterable[dict], label: str) -> None:
  print(f'\n{label}:')
  for entry in logprobs:
    token = entry.get('token', '')
    logprob = entry.get('logprob')
    print(f' token={token!r:<12} logprob={logprob:.3f}')
    for alt in entry.get('top_logprobs', []):
      if alt['token'] != token:
        print(f' alt -> {alt["token"]!r:<12} ({alt["logprob"]:.3f})')


response = ollama.generate(
  model='gemma3',
  prompt='hi! be concise.',
  logprobs=True,
  top_logprobs=3,
)
print('Generate response:', response['response'])
print_logprobs(response.get('logprobs', []), 'generate logprobs')
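The same two flags can also be combined with streaming; the client overloads below accept logprobs and top_logprobs alongside stream=True. A minimal sketch, assuming each streamed chunk carries its own optional logprobs list (as the ChatResponse type added in this commit suggests):

import ollama

# Hedged sketch: stream a chat completion and inspect per-chunk logprobs.
# Assumes each streamed chunk exposes an optional 'logprobs' list, mirroring
# the non-streaming ChatResponse shape.
for chunk in ollama.chat(
  model='gemma3',
  messages=[{'role': 'user', 'content': 'hi! be concise.'}],
  stream=True,
  logprobs=True,
  top_logprobs=3,
):
  for entry in chunk.get('logprobs') or []:
    print(f"{entry['token']!r}: {entry['logprob']:.3f}")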
@@ -200,6 +200,8 @@ class Client(BaseClient):
     context: Optional[Sequence[int]] = None,
     stream: Literal[False] = False,
     think: Optional[bool] = None,
+    logprobs: Optional[bool] = None,
+    top_logprobs: Optional[int] = None,
     raw: bool = False,
     format: Optional[Union[Literal['', 'json'], JsonSchemaValue]] = None,
     images: Optional[Sequence[Union[str, bytes, Image]]] = None,
@@ -219,6 +221,8 @@ class Client(BaseClient):
     context: Optional[Sequence[int]] = None,
     stream: Literal[True] = True,
     think: Optional[bool] = None,
+    logprobs: Optional[bool] = None,
+    top_logprobs: Optional[int] = None,
     raw: bool = False,
     format: Optional[Union[Literal['', 'json'], JsonSchemaValue]] = None,
     images: Optional[Sequence[Union[str, bytes, Image]]] = None,
@@ -237,6 +241,8 @@ class Client(BaseClient):
     context: Optional[Sequence[int]] = None,
     stream: bool = False,
     think: Optional[bool] = None,
+    logprobs: Optional[bool] = None,
+    top_logprobs: Optional[int] = None,
     raw: Optional[bool] = None,
     format: Optional[Union[Literal['', 'json'], JsonSchemaValue]] = None,
     images: Optional[Sequence[Union[str, bytes, Image]]] = None,
@@ -266,6 +272,8 @@ class Client(BaseClient):
         context=context,
         stream=stream,
         think=think,
+        logprobs=logprobs,
+        top_logprobs=top_logprobs,
         raw=raw,
         format=format,
         images=list(_copy_images(images)) if images else None,
@@ -284,6 +292,8 @@ class Client(BaseClient):
     tools: Optional[Sequence[Union[Mapping[str, Any], Tool, Callable]]] = None,
     stream: Literal[False] = False,
     think: Optional[Union[bool, Literal['low', 'medium', 'high']]] = None,
+    logprobs: Optional[bool] = None,
+    top_logprobs: Optional[int] = None,
     format: Optional[Union[Literal['', 'json'], JsonSchemaValue]] = None,
     options: Optional[Union[Mapping[str, Any], Options]] = None,
     keep_alive: Optional[Union[float, str]] = None,
@@ -298,6 +308,8 @@ class Client(BaseClient):
     tools: Optional[Sequence[Union[Mapping[str, Any], Tool, Callable]]] = None,
     stream: Literal[True] = True,
     think: Optional[Union[bool, Literal['low', 'medium', 'high']]] = None,
+    logprobs: Optional[bool] = None,
+    top_logprobs: Optional[int] = None,
     format: Optional[Union[Literal['', 'json'], JsonSchemaValue]] = None,
     options: Optional[Union[Mapping[str, Any], Options]] = None,
     keep_alive: Optional[Union[float, str]] = None,
@@ -311,6 +323,8 @@ class Client(BaseClient):
     tools: Optional[Sequence[Union[Mapping[str, Any], Tool, Callable]]] = None,
     stream: bool = False,
     think: Optional[Union[bool, Literal['low', 'medium', 'high']]] = None,
+    logprobs: Optional[bool] = None,
+    top_logprobs: Optional[int] = None,
     format: Optional[Union[Literal['', 'json'], JsonSchemaValue]] = None,
     options: Optional[Union[Mapping[str, Any], Options]] = None,
     keep_alive: Optional[Union[float, str]] = None,
@@ -358,6 +372,8 @@ class Client(BaseClient):
         tools=list(_copy_tools(tools)),
         stream=stream,
         think=think,
+        logprobs=logprobs,
+        top_logprobs=top_logprobs,
         format=format,
         options=options,
         keep_alive=keep_alive,
@@ -802,6 +818,8 @@ class AsyncClient(BaseClient):
     context: Optional[Sequence[int]] = None,
     stream: Literal[False] = False,
     think: Optional[Union[bool, Literal['low', 'medium', 'high']]] = None,
+    logprobs: Optional[bool] = None,
+    top_logprobs: Optional[int] = None,
     raw: bool = False,
     format: Optional[Union[Literal['', 'json'], JsonSchemaValue]] = None,
     images: Optional[Sequence[Union[str, bytes, Image]]] = None,
@@ -821,6 +839,8 @@ class AsyncClient(BaseClient):
     context: Optional[Sequence[int]] = None,
     stream: Literal[True] = True,
     think: Optional[Union[bool, Literal['low', 'medium', 'high']]] = None,
+    logprobs: Optional[bool] = None,
+    top_logprobs: Optional[int] = None,
     raw: bool = False,
     format: Optional[Union[Literal['', 'json'], JsonSchemaValue]] = None,
     images: Optional[Sequence[Union[str, bytes, Image]]] = None,
@@ -839,6 +859,8 @@ class AsyncClient(BaseClient):
     context: Optional[Sequence[int]] = None,
     stream: bool = False,
     think: Optional[Union[bool, Literal['low', 'medium', 'high']]] = None,
+    logprobs: Optional[bool] = None,
+    top_logprobs: Optional[int] = None,
     raw: Optional[bool] = None,
     format: Optional[Union[Literal['', 'json'], JsonSchemaValue]] = None,
     images: Optional[Sequence[Union[str, bytes, Image]]] = None,
@@ -867,6 +889,8 @@ class AsyncClient(BaseClient):
         context=context,
         stream=stream,
         think=think,
+        logprobs=logprobs,
+        top_logprobs=top_logprobs,
         raw=raw,
         format=format,
         images=list(_copy_images(images)) if images else None,
@@ -885,6 +909,8 @@ class AsyncClient(BaseClient):
     tools: Optional[Sequence[Union[Mapping[str, Any], Tool, Callable]]] = None,
     stream: Literal[False] = False,
     think: Optional[Union[bool, Literal['low', 'medium', 'high']]] = None,
+    logprobs: Optional[bool] = None,
+    top_logprobs: Optional[int] = None,
     format: Optional[Union[Literal['', 'json'], JsonSchemaValue]] = None,
     options: Optional[Union[Mapping[str, Any], Options]] = None,
     keep_alive: Optional[Union[float, str]] = None,
@@ -899,6 +925,8 @@ class AsyncClient(BaseClient):
     tools: Optional[Sequence[Union[Mapping[str, Any], Tool, Callable]]] = None,
     stream: Literal[True] = True,
     think: Optional[Union[bool, Literal['low', 'medium', 'high']]] = None,
+    logprobs: Optional[bool] = None,
+    top_logprobs: Optional[int] = None,
     format: Optional[Union[Literal['', 'json'], JsonSchemaValue]] = None,
     options: Optional[Union[Mapping[str, Any], Options]] = None,
     keep_alive: Optional[Union[float, str]] = None,
@@ -912,6 +940,8 @@ class AsyncClient(BaseClient):
     tools: Optional[Sequence[Union[Mapping[str, Any], Tool, Callable]]] = None,
     stream: bool = False,
     think: Optional[Union[bool, Literal['low', 'medium', 'high']]] = None,
+    logprobs: Optional[bool] = None,
+    top_logprobs: Optional[int] = None,
     format: Optional[Union[Literal['', 'json'], JsonSchemaValue]] = None,
     options: Optional[Union[Mapping[str, Any], Options]] = None,
     keep_alive: Optional[Union[float, str]] = None,
@@ -960,6 +990,8 @@ class AsyncClient(BaseClient):
         tools=list(_copy_tools(tools)),
         stream=stream,
         think=think,
+        logprobs=logprobs,
+        top_logprobs=top_logprobs,
         format=format,
         options=options,
         keep_alive=keep_alive,
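The AsyncClient overloads above gain the same two parameters, so an async caller can request log probabilities in the same way. A minimal sketch, reusing the model and prompt from the examples:

import asyncio

from ollama import AsyncClient


async def main() -> None:
  # The AsyncClient.generate overloads accept the same logprobs/top_logprobs
  # flags as the synchronous client (see the signatures above).
  response = await AsyncClient().generate(
    model='gemma3',
    prompt='hi! be concise.',
    logprobs=True,
    top_logprobs=3,
  )
  print(response['response'])
  for entry in response.get('logprobs') or []:
    print(f"{entry['token']!r}: {entry['logprob']:.3f}")


asyncio.run(main())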
@@ -210,6 +210,12 @@ class GenerateRequest(BaseGenerateRequest):
   think: Optional[Union[bool, Literal['low', 'medium', 'high']]] = None
   'Enable thinking mode (for thinking models).'
 
+  logprobs: Optional[bool] = None
+  'Return log probabilities for generated tokens.'
+
+  top_logprobs: Optional[int] = None
+  'Number of alternative tokens and log probabilities to include per position (0-20).'
+
 
 class BaseGenerateResponse(SubscriptableBaseModel):
   model: Optional[str] = None
@@ -243,6 +249,19 @@ class BaseGenerateResponse(SubscriptableBaseModel):
   'Duration of evaluating inference in nanoseconds.'
 
 
+class TokenLogprob(SubscriptableBaseModel):
+  token: str
+  'Token text.'
+
+  logprob: float
+  'Log probability for the token.'
+
+
+class Logprob(TokenLogprob):
+  top_logprobs: Optional[Sequence[TokenLogprob]] = None
+  'Most likely tokens and their log probabilities.'
+
+
 class GenerateResponse(BaseGenerateResponse):
   """
   Response returned by generate requests.
@@ -257,6 +276,9 @@ class GenerateResponse(BaseGenerateResponse):
   context: Optional[Sequence[int]] = None
   'Tokenized history up to the point of the response.'
 
+  logprobs: Optional[Sequence[Logprob]] = None
+  'Log probabilities for generated tokens.'
+
 
 class Message(SubscriptableBaseModel):
   """
@@ -360,6 +382,12 @@ class ChatRequest(BaseGenerateRequest):
   think: Optional[Union[bool, Literal['low', 'medium', 'high']]] = None
   'Enable thinking mode (for thinking models).'
 
+  logprobs: Optional[bool] = None
+  'Return log probabilities for generated tokens.'
+
+  top_logprobs: Optional[int] = None
+  'Number of alternative tokens and log probabilities to include per position (0-20).'
+
 
 class ChatResponse(BaseGenerateResponse):
   """
@@ -369,6 +397,9 @@ class ChatResponse(BaseGenerateResponse):
   message: Message
   'Response message.'
 
+  logprobs: Optional[Sequence[Logprob]] = None
+  'Log probabilities for generated tokens if requested.'
+
 
 class EmbedRequest(BaseRequest):
   input: Union[str, Sequence[str]]
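Because TokenLogprob and Logprob are SubscriptableBaseModel subclasses, the logprobs entries on a ChatResponse or GenerateResponse can be read by attribute as well as by key. A short sketch of walking the typed structure defined above:

from ollama import chat

# Sketch: attribute access over the typed logprobs structure.
# response.logprobs is Optional[Sequence[Logprob]]; each Logprob carries the
# chosen token, its log probability, and optionally the top alternatives.
response = chat(
  model='gemma3',
  messages=[{'role': 'user', 'content': 'hi! be concise.'}],
  logprobs=True,
  top_logprobs=2,
)
for entry in response.logprobs or []:
  print(entry.token, entry.logprob)
  for alt in entry.top_logprobs or []:
    print('  alt:', alt.token, alt.logprob)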
@@ -37,7 +37,7 @@ dependencies = [ 'ruff>=0.9.1' ]
 config-path = 'none'
 
 [tool.ruff]
-line-length = 999
+line-length = 320
 indent-width = 2
 
 [tool.ruff.format]
@@ -61,6 +61,44 @@ def test_client_chat(httpserver: HTTPServer):
   assert response['message']['content'] == "I don't know."
 
 
+def test_client_chat_with_logprobs(httpserver: HTTPServer):
+  httpserver.expect_ordered_request(
+    '/api/chat',
+    method='POST',
+    json={
+      'model': 'dummy',
+      'messages': [{'role': 'user', 'content': 'Hi'}],
+      'tools': [],
+      'stream': False,
+      'logprobs': True,
+      'top_logprobs': 3,
+    },
+  ).respond_with_json(
+    {
+      'model': 'dummy',
+      'message': {
+        'role': 'assistant',
+        'content': 'Hello',
+      },
+      'logprobs': [
+        {
+          'token': 'Hello',
+          'logprob': -0.1,
+          'top_logprobs': [
+            {'token': 'Hello', 'logprob': -0.1},
+            {'token': 'Hi', 'logprob': -1.0},
+          ],
+        }
+      ],
+    }
+  )
+
+  client = Client(httpserver.url_for('/'))
+  response = client.chat('dummy', messages=[{'role': 'user', 'content': 'Hi'}], logprobs=True, top_logprobs=3)
+  assert response['logprobs'][0]['token'] == 'Hello'
+  assert response['logprobs'][0]['top_logprobs'][1]['token'] == 'Hi'
+
+
 def test_client_chat_stream(httpserver: HTTPServer):
   def stream_handler(_: Request):
     def generate():
@@ -294,6 +332,40 @@ def test_client_generate(httpserver: HTTPServer):
   assert response['response'] == 'Because it is.'
 
 
+def test_client_generate_with_logprobs(httpserver: HTTPServer):
+  httpserver.expect_ordered_request(
+    '/api/generate',
+    method='POST',
+    json={
+      'model': 'dummy',
+      'prompt': 'Why',
+      'stream': False,
+      'logprobs': True,
+      'top_logprobs': 2,
+    },
+  ).respond_with_json(
+    {
+      'model': 'dummy',
+      'response': 'Hello',
+      'logprobs': [
+        {
+          'token': 'Hello',
+          'logprob': -0.2,
+          'top_logprobs': [
+            {'token': 'Hello', 'logprob': -0.2},
+            {'token': 'Hi', 'logprob': -1.5},
+          ],
+        }
+      ],
+    }
+  )
+
+  client = Client(httpserver.url_for('/'))
+  response = client.generate('dummy', 'Why', logprobs=True, top_logprobs=2)
+  assert response['logprobs'][0]['token'] == 'Hello'
+  assert response['logprobs'][0]['top_logprobs'][1]['token'] == 'Hi'
+
+
 def test_client_generate_with_image_type(httpserver: HTTPServer):
   httpserver.expect_ordered_request(
     '/api/generate',