add generate and chat responses

Michael Yang 2024-01-09 17:17:44 -08:00
parent 0f9211de97
commit cea391a041
3 changed files with 79 additions and 11 deletions

ollama/__init__.py

@@ -1,9 +1,16 @@
 from ollama._client import Client, AsyncClient
-from ollama._types import Message, Options
+from ollama._types import (
+  GenerateResponse,
+  ChatResponse,
+  Message,
+  Options,
+)
 
 __all__ = [
   'Client',
   'AsyncClient',
+  'GenerateResponse',
+  'ChatResponse',
   'Message',
   'Options',
   'generate',
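The new exports make the response shapes importable from the package root. A minimal sketch of what that enables, assuming a local Ollama server and an already pulled `llama2` model (the annotation only guides type checkers; the call still returns a plain mapping at runtime):

  from ollama import Client, GenerateResponse

  client = Client()
  # Non-streaming call; the result is a GenerateResponse-shaped dict.
  response: GenerateResponse = client.generate(model='llama2', prompt='Why is the sky blue?')
  print(response['response'])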

ollama/_client.py

@@ -7,7 +7,7 @@ from pathlib import Path
 from hashlib import sha256
 from base64 import b64encode
-from typing import Any, AnyStr, Union, Optional, List, Mapping
+from typing import Any, AnyStr, Union, Optional, Sequence, Mapping
 
 import sys
@@ -56,13 +56,17 @@ class Client(BaseClient):
     prompt: str = '',
     system: str = '',
     template: str = '',
-    context: Optional[List[int]] = None,
+    context: Optional[Sequence[int]] = None,
     stream: bool = False,
     raw: bool = False,
     format: str = '',
-    images: Optional[List[AnyStr]] = None,
+    images: Optional[Sequence[AnyStr]] = None,
     options: Optional[Options] = None,
   ) -> Union[Mapping[str, Any], Iterator[Mapping[str, Any]]]:
+    """
+    Returns `GenerateResponse` if `stream` is `False`, otherwise returns a `GenerateResponse` generator.
+    """
     if not model:
       raise Exception('must provide a model')
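A hedged usage sketch of the two return shapes the new docstring names, assuming the same local server and `llama2` model:

  client = Client()

  # stream=False: one complete GenerateResponse-shaped mapping.
  resp = client.generate(model='llama2', prompt='Why is the sky blue?')
  print(resp['response'])

  # stream=True: an iterator of partial GenerateResponse chunks.
  for chunk in client.generate(model='llama2', prompt='Why is the sky blue?', stream=True):
    print(chunk['response'], end='', flush=True)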
@@ -87,11 +91,15 @@ class Client(BaseClient):
   def chat(
     self,
     model: str = '',
-    messages: Optional[List[Message]] = None,
+    messages: Optional[Sequence[Message]] = None,
     stream: bool = False,
     format: str = '',
     options: Optional[Options] = None,
   ) -> Union[Mapping[str, Any], Iterator[Mapping[str, Any]]]:
+    """
+    Returns `ChatResponse` if `stream` is `False`, otherwise returns a `ChatResponse` generator.
+    """
     if not model:
       raise Exception('must provide a model')
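Chat follows the same pattern, taking a sequence of Message dicts; a minimal sketch under the same assumptions:

  messages = [{'role': 'user', 'content': 'Why is the sky blue?'}]

  # stream=False: a single ChatResponse-shaped mapping.
  resp = client.chat(model='llama2', messages=messages)
  print(resp['message']['content'])

  # stream=True: an iterator of ChatResponse chunks with partial content.
  for chunk in client.chat(model='llama2', messages=messages, stream=True):
    print(chunk['message']['content'], end='', flush=True)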
@@ -124,6 +132,9 @@
     insecure: bool = False,
     stream: bool = False,
   ) -> Union[Mapping[str, Any], Iterator[Mapping[str, Any]]]:
+    """
+    Returns `ProgressResponse` if `stream` is `False`, otherwise returns a `ProgressResponse` generator.
+    """
     return self._request_stream(
       'POST',
       '/api/pull',
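Streaming a pull makes the new ProgressResponse shape visible; a sketch (some fields are absent for some statuses, hence `.get`):

  for progress in client.pull('llama2', stream=True):
    # Each chunk reports a status plus completed/total byte counts for layers.
    print(progress.get('status'), progress.get('completed'), progress.get('total'))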
@@ -141,6 +152,9 @@
     insecure: bool = False,
     stream: bool = False,
   ) -> Union[Mapping[str, Any], Iterator[Mapping[str, Any]]]:
+    """
+    Returns `ProgressResponse` if `stream` is `False`, otherwise returns a `ProgressResponse` generator.
+    """
     return self._request_stream(
       'POST',
       '/api/push',
@@ -159,6 +173,9 @@
     modelfile: Optional[str] = None,
     stream: bool = False,
   ) -> Union[Mapping[str, Any], Iterator[Mapping[str, Any]]]:
+    """
+    Returns `ProgressResponse` if `stream` is `False`, otherwise returns a `ProgressResponse` generator.
+    """
     if (realpath := _as_path(path)) and realpath.exists():
       modelfile = self._parse_modelfile(realpath.read_text(), base=realpath.parent)
     elif modelfile:
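The walrus branch above prefers a Modelfile on disk and falls back to inline text. A hedged sketch of the inline form (the model name and Modelfile contents are illustrative):

  modelfile = 'FROM llama2\nSYSTEM You are a helpful assistant.'
  for progress in client.create('my-assistant', modelfile=modelfile, stream=True):
    print(progress.get('status'))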
@@ -267,13 +284,16 @@ class AsyncClient(BaseClient):
     prompt: str = '',
     system: str = '',
     template: str = '',
-    context: Optional[List[int]] = None,
+    context: Optional[Sequence[int]] = None,
     stream: bool = False,
     raw: bool = False,
     format: str = '',
-    images: Optional[List[AnyStr]] = None,
+    images: Optional[Sequence[AnyStr]] = None,
     options: Optional[Options] = None,
   ) -> Union[Mapping[str, Any], AsyncIterator[Mapping[str, Any]]]:
+    """
+    Returns `GenerateResponse` if `stream` is `False`, otherwise returns an asynchronous `GenerateResponse` generator.
+    """
     if not model:
       raise Exception('must provide a model')
@@ -298,11 +318,14 @@ class AsyncClient(BaseClient):
   async def chat(
     self,
     model: str = '',
-    messages: Optional[List[Message]] = None,
+    messages: Optional[Sequence[Message]] = None,
     stream: bool = False,
     format: str = '',
     options: Optional[Options] = None,
   ) -> Union[Mapping[str, Any], AsyncIterator[Mapping[str, Any]]]:
+    """
+    Returns `ChatResponse` if `stream` is `False`, otherwise returns an asynchronous `ChatResponse` generator.
+    """
     if not model:
       raise Exception('must provide a model')
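On the async side each call is awaited, and a streaming call evaluates to an async generator, so it is awaited first and then iterated with async for. A minimal sketch, assuming the same local server:

  import asyncio
  from ollama import AsyncClient

  async def main():
    client = AsyncClient()
    messages = [{'role': 'user', 'content': 'Why is the sky blue?'}]

    # stream=False: await a single ChatResponse-shaped mapping.
    resp = await client.chat(model='llama2', messages=messages)
    print(resp['message']['content'])

    # stream=True: awaiting yields an async generator of partial chunks.
    async for chunk in await client.chat(model='llama2', messages=messages, stream=True):
      print(chunk['message']['content'], end='', flush=True)

  asyncio.run(main())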
@@ -335,6 +358,9 @@
     insecure: bool = False,
     stream: bool = False,
   ) -> Union[Mapping[str, Any], AsyncIterator[Mapping[str, Any]]]:
+    """
+    Returns `ProgressResponse` if `stream` is `False`, otherwise returns a `ProgressResponse` generator.
+    """
     return await self._request_stream(
       'POST',
       '/api/pull',
@@ -352,6 +378,9 @@
     insecure: bool = False,
     stream: bool = False,
   ) -> Union[Mapping[str, Any], AsyncIterator[Mapping[str, Any]]]:
+    """
+    Returns `ProgressResponse` if `stream` is `False`, otherwise returns a `ProgressResponse` generator.
+    """
     return await self._request_stream(
       'POST',
       '/api/push',
@@ -370,6 +399,9 @@
     modelfile: Optional[str] = None,
     stream: bool = False,
   ) -> Union[Mapping[str, Any], AsyncIterator[Mapping[str, Any]]]:
+    """
+    Returns `ProgressResponse` if `stream` is `False`, otherwise returns a `ProgressResponse` generator.
+    """
     if (realpath := _as_path(path)) and realpath.exists():
       modelfile = await self._parse_modelfile(realpath.read_text(), base=realpath.parent)
     elif modelfile:

ollama/_types.py

@@ -1,4 +1,4 @@
-from typing import Any, TypedDict, List
+from typing import Any, TypedDict, Sequence
 
 import sys
@@ -8,10 +8,39 @@ else:
   from typing import NotRequired
 
 
+class BaseGenerateResponse(TypedDict):
+  model: str
+  created_at: str
+  done: bool
+  total_duration: int
+  load_duration: int
+  prompt_eval_count: int
+  prompt_eval_duration: int
+  eval_count: int
+  eval_duration: int
+
+
+class GenerateResponse(BaseGenerateResponse):
+  response: str
+  context: Sequence[int]
+
+
 class Message(TypedDict):
   role: str
   content: str
-  images: NotRequired[List[Any]]
+  images: NotRequired[Sequence[Any]]
+
+
+class ChatResponse(BaseGenerateResponse):
+  message: Message
+
+
+class ProgressResponse(TypedDict):
+  status: str
+  completed: int
+  total: int
+  digest: str
 
 
 class Options(TypedDict, total=False):
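These classes are ordinary dicts at runtime; the TypedDict machinery exists for static checkers. A short sketch of downstream use (the helper function is hypothetical):

  from ollama._types import Message, ChatResponse

  # A Message is a plain dict literal; `images` is NotRequired and may be omitted.
  msg: Message = {'role': 'user', 'content': 'Why is the sky blue?'}

  def last_reply(resp: ChatResponse) -> str:
    # A checker knows resp['message'] is a Message with a str 'content' field.
    return resp['message']['content']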
@@ -50,4 +79,4 @@ class Options(TypedDict, total=False):
   mirostat_tau: float
   mirostat_eta: float
   penalize_newline: bool
-  stop: List[str]
+  stop: Sequence[str]
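Because Options is declared with total=False, any subset of its keys type-checks. A hedged example (the option values are illustrative):

  from ollama import Client
  from ollama._types import Options

  opts: Options = {'temperature': 0.2, 'stop': ['\n\n']}
  resp = Client().generate(model='llama2', prompt='Why is the sky blue?', options=opts)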