mirror of https://github.com/ollama/ollama-python.git
synced 2026-01-13 21:57:16 +08:00
client/types: add logprobs support (#601)
This commit is contained in:
parent 9ddd5f0182
commit 0008226fda
examples/chat-logprobs.py (new file, 31 lines)
@ -0,0 +1,31 @@
from typing import Iterable

import ollama


def print_logprobs(logprobs: Iterable[dict], label: str) -> None:
  print(f'\n{label}:')
  for entry in logprobs:
    token = entry.get('token', '')
    logprob = entry.get('logprob')
    print(f'  token={token!r:<12} logprob={logprob:.3f}')
    for alt in entry.get('top_logprobs', []):
      if alt['token'] != token:
        print(f'    alt -> {alt["token"]!r:<12} ({alt["logprob"]:.3f})')


messages = [
  {
    'role': 'user',
    'content': 'hi! be concise.',
  },
]

response = ollama.chat(
  model='gemma3',
  messages=messages,
  logprobs=True,
  top_logprobs=3,
)
print('Chat response:', response['message']['content'])
print_logprobs(response.get('logprobs', []), 'chat logprobs')
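A note on reading the output: each logprob is the natural log of the model's probability for the emitted token, so values near 0 mean high confidence. A small sketch (not part of this commit) that continues the example above and converts entries back to plain probabilities with math.exp:

import math


def to_probability(logprob: float) -> float:
  # exp(logprob) recovers the model's probability, in [0, 1]
  return math.exp(logprob)


# e.g. a logprob of -0.1 corresponds to a probability of ~0.905
for entry in response.get('logprobs', []):
  print(f"{entry['token']!r}: p={to_probability(entry['logprob']):.3f}")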
@ -15,7 +15,8 @@ messages = [
   },
   {
     'role': 'assistant',
-    'content': 'The weather in Tokyo is typically warm and humid during the summer months, with temperatures often exceeding 30°C (86°F). The city experiences a rainy season from June to September, with heavy rainfall and occasional typhoons. Winter is mild, with temperatures rarely dropping below freezing. The city is known for its high-tech and vibrant culture, with many popular tourist attractions such as the Tokyo Tower, Senso-ji Temple, and the bustling Shibuya district.',
+    'content': """The weather in Tokyo is typically warm and humid during the summer months, with temperatures often exceeding 30°C (86°F). The city experiences a rainy season from June to September, with heavy rainfall and occasional typhoons. Winter is mild, with temperatures
+rarely dropping below freezing. The city is known for its high-tech and vibrant culture, with many popular tourist attractions such as the Tokyo Tower, Senso-ji Temple, and the bustling Shibuya district.""",
   },
 ]
examples/generate-logprobs.py (new file, 24 lines)
@ -0,0 +1,24 @@
from typing import Iterable

import ollama


def print_logprobs(logprobs: Iterable[dict], label: str) -> None:
  print(f'\n{label}:')
  for entry in logprobs:
    token = entry.get('token', '')
    logprob = entry.get('logprob')
    print(f'  token={token!r:<12} logprob={logprob:.3f}')
    for alt in entry.get('top_logprobs', []):
      if alt['token'] != token:
        print(f'    alt -> {alt["token"]!r:<12} ({alt["logprob"]:.3f})')


response = ollama.generate(
  model='gemma3',
  prompt='hi! be concise.',
  logprobs=True,
  top_logprobs=3,
)
print('Generate response:', response['response'])
print_logprobs(response.get('logprobs', []), 'generate logprobs')
@ -200,6 +200,8 @@ class Client(BaseClient):
     context: Optional[Sequence[int]] = None,
     stream: Literal[False] = False,
     think: Optional[bool] = None,
+    logprobs: Optional[bool] = None,
+    top_logprobs: Optional[int] = None,
     raw: bool = False,
     format: Optional[Union[Literal['', 'json'], JsonSchemaValue]] = None,
     images: Optional[Sequence[Union[str, bytes, Image]]] = None,

@ -219,6 +221,8 @@ class Client(BaseClient):
     context: Optional[Sequence[int]] = None,
     stream: Literal[True] = True,
     think: Optional[bool] = None,
+    logprobs: Optional[bool] = None,
+    top_logprobs: Optional[int] = None,
     raw: bool = False,
     format: Optional[Union[Literal['', 'json'], JsonSchemaValue]] = None,
     images: Optional[Sequence[Union[str, bytes, Image]]] = None,

@ -237,6 +241,8 @@ class Client(BaseClient):
     context: Optional[Sequence[int]] = None,
     stream: bool = False,
     think: Optional[bool] = None,
+    logprobs: Optional[bool] = None,
+    top_logprobs: Optional[int] = None,
     raw: Optional[bool] = None,
     format: Optional[Union[Literal['', 'json'], JsonSchemaValue]] = None,
     images: Optional[Sequence[Union[str, bytes, Image]]] = None,

@ -266,6 +272,8 @@ class Client(BaseClient):
       context=context,
       stream=stream,
       think=think,
+      logprobs=logprobs,
+      top_logprobs=top_logprobs,
       raw=raw,
       format=format,
       images=list(_copy_images(images)) if images else None,

@ -284,6 +292,8 @@ class Client(BaseClient):
     tools: Optional[Sequence[Union[Mapping[str, Any], Tool, Callable]]] = None,
     stream: Literal[False] = False,
     think: Optional[Union[bool, Literal['low', 'medium', 'high']]] = None,
+    logprobs: Optional[bool] = None,
+    top_logprobs: Optional[int] = None,
     format: Optional[Union[Literal['', 'json'], JsonSchemaValue]] = None,
     options: Optional[Union[Mapping[str, Any], Options]] = None,
     keep_alive: Optional[Union[float, str]] = None,

@ -298,6 +308,8 @@ class Client(BaseClient):
     tools: Optional[Sequence[Union[Mapping[str, Any], Tool, Callable]]] = None,
     stream: Literal[True] = True,
     think: Optional[Union[bool, Literal['low', 'medium', 'high']]] = None,
+    logprobs: Optional[bool] = None,
+    top_logprobs: Optional[int] = None,
     format: Optional[Union[Literal['', 'json'], JsonSchemaValue]] = None,
     options: Optional[Union[Mapping[str, Any], Options]] = None,
     keep_alive: Optional[Union[float, str]] = None,

@ -311,6 +323,8 @@ class Client(BaseClient):
     tools: Optional[Sequence[Union[Mapping[str, Any], Tool, Callable]]] = None,
     stream: bool = False,
     think: Optional[Union[bool, Literal['low', 'medium', 'high']]] = None,
+    logprobs: Optional[bool] = None,
+    top_logprobs: Optional[int] = None,
     format: Optional[Union[Literal['', 'json'], JsonSchemaValue]] = None,
     options: Optional[Union[Mapping[str, Any], Options]] = None,
     keep_alive: Optional[Union[float, str]] = None,

@ -358,6 +372,8 @@ class Client(BaseClient):
       tools=list(_copy_tools(tools)),
       stream=stream,
       think=think,
+      logprobs=logprobs,
+      top_logprobs=top_logprobs,
       format=format,
       options=options,
       keep_alive=keep_alive,
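Since the overloads above thread logprobs and top_logprobs through every stream variant, the flags can also be combined with streaming. A minimal sketch, not part of this diff, which assumes that streamed chunks expose the same optional logprobs field as the response types changed below:

from ollama import Client

client = Client()

collected = []
for chunk in client.chat(
  model='gemma3',  # any locally available model works here
  messages=[{'role': 'user', 'content': 'hi! be concise.'}],
  stream=True,
  logprobs=True,
  top_logprobs=3,
):
  print(chunk['message']['content'], end='', flush=True)
  # assumption: each streamed chunk may carry logprobs for its tokens
  collected.extend(chunk.get('logprobs') or [])

print(f'\ncollected {len(collected)} token logprobs')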
@ -802,6 +818,8 @@ class AsyncClient(BaseClient):
     context: Optional[Sequence[int]] = None,
     stream: Literal[False] = False,
     think: Optional[Union[bool, Literal['low', 'medium', 'high']]] = None,
+    logprobs: Optional[bool] = None,
+    top_logprobs: Optional[int] = None,
     raw: bool = False,
     format: Optional[Union[Literal['', 'json'], JsonSchemaValue]] = None,
     images: Optional[Sequence[Union[str, bytes, Image]]] = None,

@ -821,6 +839,8 @@ class AsyncClient(BaseClient):
     context: Optional[Sequence[int]] = None,
     stream: Literal[True] = True,
     think: Optional[Union[bool, Literal['low', 'medium', 'high']]] = None,
+    logprobs: Optional[bool] = None,
+    top_logprobs: Optional[int] = None,
     raw: bool = False,
     format: Optional[Union[Literal['', 'json'], JsonSchemaValue]] = None,
     images: Optional[Sequence[Union[str, bytes, Image]]] = None,

@ -839,6 +859,8 @@ class AsyncClient(BaseClient):
     context: Optional[Sequence[int]] = None,
     stream: bool = False,
     think: Optional[Union[bool, Literal['low', 'medium', 'high']]] = None,
+    logprobs: Optional[bool] = None,
+    top_logprobs: Optional[int] = None,
     raw: Optional[bool] = None,
     format: Optional[Union[Literal['', 'json'], JsonSchemaValue]] = None,
     images: Optional[Sequence[Union[str, bytes, Image]]] = None,

@ -867,6 +889,8 @@ class AsyncClient(BaseClient):
       context=context,
       stream=stream,
       think=think,
+      logprobs=logprobs,
+      top_logprobs=top_logprobs,
       raw=raw,
       format=format,
       images=list(_copy_images(images)) if images else None,

@ -885,6 +909,8 @@ class AsyncClient(BaseClient):
     tools: Optional[Sequence[Union[Mapping[str, Any], Tool, Callable]]] = None,
     stream: Literal[False] = False,
     think: Optional[Union[bool, Literal['low', 'medium', 'high']]] = None,
+    logprobs: Optional[bool] = None,
+    top_logprobs: Optional[int] = None,
     format: Optional[Union[Literal['', 'json'], JsonSchemaValue]] = None,
     options: Optional[Union[Mapping[str, Any], Options]] = None,
     keep_alive: Optional[Union[float, str]] = None,

@ -899,6 +925,8 @@ class AsyncClient(BaseClient):
     tools: Optional[Sequence[Union[Mapping[str, Any], Tool, Callable]]] = None,
     stream: Literal[True] = True,
     think: Optional[Union[bool, Literal['low', 'medium', 'high']]] = None,
+    logprobs: Optional[bool] = None,
+    top_logprobs: Optional[int] = None,
     format: Optional[Union[Literal['', 'json'], JsonSchemaValue]] = None,
     options: Optional[Union[Mapping[str, Any], Options]] = None,
     keep_alive: Optional[Union[float, str]] = None,

@ -912,6 +940,8 @@ class AsyncClient(BaseClient):
     tools: Optional[Sequence[Union[Mapping[str, Any], Tool, Callable]]] = None,
     stream: bool = False,
     think: Optional[Union[bool, Literal['low', 'medium', 'high']]] = None,
+    logprobs: Optional[bool] = None,
+    top_logprobs: Optional[int] = None,
     format: Optional[Union[Literal['', 'json'], JsonSchemaValue]] = None,
     options: Optional[Union[Mapping[str, Any], Options]] = None,
     keep_alive: Optional[Union[float, str]] = None,

@ -960,6 +990,8 @@ class AsyncClient(BaseClient):
       tools=list(_copy_tools(tools)),
       stream=stream,
       think=think,
+      logprobs=logprobs,
+      top_logprobs=top_logprobs,
       format=format,
       options=options,
       keep_alive=keep_alive,
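AsyncClient mirrors the synchronous signatures, so the new parameters work unchanged under asyncio. A minimal sketch, not part of this diff:

import asyncio

from ollama import AsyncClient


async def main() -> None:
  client = AsyncClient()
  response = await client.chat(
    model='gemma3',  # any locally available model works here
    messages=[{'role': 'user', 'content': 'hi! be concise.'}],
    logprobs=True,
    top_logprobs=3,
  )
  for entry in response.get('logprobs') or []:
    print(entry['token'], entry['logprob'])


asyncio.run(main())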
@ -210,6 +210,12 @@ class GenerateRequest(BaseGenerateRequest):
   think: Optional[Union[bool, Literal['low', 'medium', 'high']]] = None
   'Enable thinking mode (for thinking models).'

+  logprobs: Optional[bool] = None
+  'Return log probabilities for generated tokens.'
+
+  top_logprobs: Optional[int] = None
+  'Number of alternative tokens and log probabilities to include per position (0-20).'
+

 class BaseGenerateResponse(SubscriptableBaseModel):
   model: Optional[str] = None

@ -243,6 +249,19 @@ class BaseGenerateResponse(SubscriptableBaseModel):
   'Duration of evaluating inference in nanoseconds.'

+
+class TokenLogprob(SubscriptableBaseModel):
+  token: str
+  'Token text.'
+
+  logprob: float
+  'Log probability for the token.'
+
+
+class Logprob(TokenLogprob):
+  top_logprobs: Optional[Sequence[TokenLogprob]] = None
+  'Most likely tokens and their log probabilities.'
+

 class GenerateResponse(BaseGenerateResponse):
   """
   Response returned by generate requests.

@ -257,6 +276,9 @@ class GenerateResponse(BaseGenerateResponse):
   context: Optional[Sequence[int]] = None
   'Tokenized history up to the point of the response.'

+  logprobs: Optional[Sequence[Logprob]] = None
+  'Log probabilities for generated tokens.'
+

 class Message(SubscriptableBaseModel):
   """

@ -360,6 +382,12 @@ class ChatRequest(BaseGenerateRequest):
   think: Optional[Union[bool, Literal['low', 'medium', 'high']]] = None
   'Enable thinking mode (for thinking models).'

+  logprobs: Optional[bool] = None
+  'Return log probabilities for generated tokens.'
+
+  top_logprobs: Optional[int] = None
+  'Number of alternative tokens and log probabilities to include per position (0-20).'
+

 class ChatResponse(BaseGenerateResponse):
   """

@ -369,6 +397,9 @@ class ChatResponse(BaseGenerateResponse):
   message: Message
   'Response message.'

+  logprobs: Optional[Sequence[Logprob]] = None
+  'Log probabilities for generated tokens if requested.'
+

 class EmbedRequest(BaseRequest):
   input: Union[str, Sequence[str]]
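Because Logprob extends TokenLogprob and nests a sequence of TokenLogprob alternatives, a raw JSON payload validates into typed objects, while SubscriptableBaseModel keeps dict-style access working. An illustrative sketch, not part of this diff; the ollama._types import path is an assumption:

from ollama._types import Logprob  # assumed import path

raw = {
  'token': 'Hello',
  'logprob': -0.1,
  'top_logprobs': [
    {'token': 'Hello', 'logprob': -0.1},
    {'token': 'Hi', 'logprob': -1.0},
  ],
}

entry = Logprob(**raw)
assert entry.token == 'Hello'
# nested dicts are validated into TokenLogprob instances
assert entry.top_logprobs[1].token == 'Hi'
# subscript access comes from SubscriptableBaseModel
assert entry['logprob'] == -0.1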
@ -37,7 +37,7 @@ dependencies = [ 'ruff>=0.9.1' ]
 config-path = 'none'

 [tool.ruff]
-line-length = 999
+line-length = 320
 indent-width = 2

 [tool.ruff.format]
@ -61,6 +61,44 @@ def test_client_chat(httpserver: HTTPServer):
   assert response['message']['content'] == "I don't know."


+def test_client_chat_with_logprobs(httpserver: HTTPServer):
+  httpserver.expect_ordered_request(
+    '/api/chat',
+    method='POST',
+    json={
+      'model': 'dummy',
+      'messages': [{'role': 'user', 'content': 'Hi'}],
+      'tools': [],
+      'stream': False,
+      'logprobs': True,
+      'top_logprobs': 3,
+    },
+  ).respond_with_json(
+    {
+      'model': 'dummy',
+      'message': {
+        'role': 'assistant',
+        'content': 'Hello',
+      },
+      'logprobs': [
+        {
+          'token': 'Hello',
+          'logprob': -0.1,
+          'top_logprobs': [
+            {'token': 'Hello', 'logprob': -0.1},
+            {'token': 'Hi', 'logprob': -1.0},
+          ],
+        }
+      ],
+    }
+  )
+
+  client = Client(httpserver.url_for('/'))
+  response = client.chat('dummy', messages=[{'role': 'user', 'content': 'Hi'}], logprobs=True, top_logprobs=3)
+  assert response['logprobs'][0]['token'] == 'Hello'
+  assert response['logprobs'][0]['top_logprobs'][1]['token'] == 'Hi'
+
+
 def test_client_chat_stream(httpserver: HTTPServer):
   def stream_handler(_: Request):
     def generate():

@ -294,6 +332,40 @@ def test_client_generate(httpserver: HTTPServer):
   assert response['response'] == 'Because it is.'


+def test_client_generate_with_logprobs(httpserver: HTTPServer):
+  httpserver.expect_ordered_request(
+    '/api/generate',
+    method='POST',
+    json={
+      'model': 'dummy',
+      'prompt': 'Why',
+      'stream': False,
+      'logprobs': True,
+      'top_logprobs': 2,
+    },
+  ).respond_with_json(
+    {
+      'model': 'dummy',
+      'response': 'Hello',
+      'logprobs': [
+        {
+          'token': 'Hello',
+          'logprob': -0.2,
+          'top_logprobs': [
+            {'token': 'Hello', 'logprob': -0.2},
+            {'token': 'Hi', 'logprob': -1.5},
+          ],
+        }
+      ],
+    }
+  )
+
+  client = Client(httpserver.url_for('/'))
+  response = client.generate('dummy', 'Why', logprobs=True, top_logprobs=2)
+  assert response['logprobs'][0]['token'] == 'Hello'
+  assert response['logprobs'][0]['top_logprobs'][1]['token'] == 'Hi'
+
+
 def test_client_generate_with_image_type(httpserver: HTTPServer):
   httpserver.expect_ordered_request(
     '/api/generate',