mirror of
https://github.com/ollama/ollama-python.git
synced 2026-05-03 12:52:35 +00:00
client/types: add logprobs support (#601)
This commit is contained in:
@@ -61,6 +61,44 @@ def test_client_chat(httpserver: HTTPServer):
|
||||
assert response['message']['content'] == "I don't know."
|
||||
|
||||
|
||||
def test_client_chat_with_logprobs(httpserver: HTTPServer):
|
||||
httpserver.expect_ordered_request(
|
||||
'/api/chat',
|
||||
method='POST',
|
||||
json={
|
||||
'model': 'dummy',
|
||||
'messages': [{'role': 'user', 'content': 'Hi'}],
|
||||
'tools': [],
|
||||
'stream': False,
|
||||
'logprobs': True,
|
||||
'top_logprobs': 3,
|
||||
},
|
||||
).respond_with_json(
|
||||
{
|
||||
'model': 'dummy',
|
||||
'message': {
|
||||
'role': 'assistant',
|
||||
'content': 'Hello',
|
||||
},
|
||||
'logprobs': [
|
||||
{
|
||||
'token': 'Hello',
|
||||
'logprob': -0.1,
|
||||
'top_logprobs': [
|
||||
{'token': 'Hello', 'logprob': -0.1},
|
||||
{'token': 'Hi', 'logprob': -1.0},
|
||||
],
|
||||
}
|
||||
],
|
||||
}
|
||||
)
|
||||
|
||||
client = Client(httpserver.url_for('/'))
|
||||
response = client.chat('dummy', messages=[{'role': 'user', 'content': 'Hi'}], logprobs=True, top_logprobs=3)
|
||||
assert response['logprobs'][0]['token'] == 'Hello'
|
||||
assert response['logprobs'][0]['top_logprobs'][1]['token'] == 'Hi'
|
||||
|
||||
|
||||
def test_client_chat_stream(httpserver: HTTPServer):
|
||||
def stream_handler(_: Request):
|
||||
def generate():
|
||||
@@ -294,6 +332,40 @@ def test_client_generate(httpserver: HTTPServer):
|
||||
assert response['response'] == 'Because it is.'
|
||||
|
||||
|
||||
def test_client_generate_with_logprobs(httpserver: HTTPServer):
|
||||
httpserver.expect_ordered_request(
|
||||
'/api/generate',
|
||||
method='POST',
|
||||
json={
|
||||
'model': 'dummy',
|
||||
'prompt': 'Why',
|
||||
'stream': False,
|
||||
'logprobs': True,
|
||||
'top_logprobs': 2,
|
||||
},
|
||||
).respond_with_json(
|
||||
{
|
||||
'model': 'dummy',
|
||||
'response': 'Hello',
|
||||
'logprobs': [
|
||||
{
|
||||
'token': 'Hello',
|
||||
'logprob': -0.2,
|
||||
'top_logprobs': [
|
||||
{'token': 'Hello', 'logprob': -0.2},
|
||||
{'token': 'Hi', 'logprob': -1.5},
|
||||
],
|
||||
}
|
||||
],
|
||||
}
|
||||
)
|
||||
|
||||
client = Client(httpserver.url_for('/'))
|
||||
response = client.generate('dummy', 'Why', logprobs=True, top_logprobs=2)
|
||||
assert response['logprobs'][0]['token'] == 'Hello'
|
||||
assert response['logprobs'][0]['top_logprobs'][1]['token'] == 'Hi'
|
||||
|
||||
|
||||
def test_client_generate_with_image_type(httpserver: HTTPServer):
|
||||
httpserver.expect_ordered_request(
|
||||
'/api/generate',
|
||||
|
||||
Reference in New Issue
Block a user