mirror of
https://github.com/ollama/ollama-python.git
synced 2026-01-13 21:57:16 +08:00
add generate and chat responses
This commit is contained in:
parent
0f9211de97
commit
cea391a041
@ -1,9 +1,16 @@
|
||||
from ollama._client import Client, AsyncClient
|
||||
from ollama._types import Message, Options
|
||||
from ollama._types import (
|
||||
GenerateResponse,
|
||||
ChatResponse,
|
||||
Message,
|
||||
Options,
|
||||
)
|
||||
|
||||
__all__ = [
|
||||
'Client',
|
||||
'AsyncClient',
|
||||
'GenerateResponse',
|
||||
'ChatResponse',
|
||||
'Message',
|
||||
'Options',
|
||||
'generate',
|
||||
|
||||
@ -7,7 +7,7 @@ from pathlib import Path
|
||||
from hashlib import sha256
|
||||
from base64 import b64encode
|
||||
|
||||
from typing import Any, AnyStr, Union, Optional, List, Mapping
|
||||
from typing import Any, AnyStr, Union, Optional, Sequence, Mapping
|
||||
|
||||
import sys
|
||||
|
||||
@ -56,13 +56,17 @@ class Client(BaseClient):
|
||||
prompt: str = '',
|
||||
system: str = '',
|
||||
template: str = '',
|
||||
context: Optional[List[int]] = None,
|
||||
context: Optional[Sequence[int]] = None,
|
||||
stream: bool = False,
|
||||
raw: bool = False,
|
||||
format: str = '',
|
||||
images: Optional[List[AnyStr]] = None,
|
||||
images: Optional[Sequence[AnyStr]] = None,
|
||||
options: Optional[Options] = None,
|
||||
) -> Union[Mapping[str, Any], Iterator[Mapping[str, Any]]]:
|
||||
"""
|
||||
Returns `GenerateResponse` if `stream` is `False`, otherwise returns a `GenerateResponse` generator.
|
||||
"""
|
||||
|
||||
if not model:
|
||||
raise Exception('must provide a model')
|
||||
|
||||
@ -87,11 +91,15 @@ class Client(BaseClient):
|
||||
def chat(
|
||||
self,
|
||||
model: str = '',
|
||||
messages: Optional[List[Message]] = None,
|
||||
messages: Optional[Sequence[Message]] = None,
|
||||
stream: bool = False,
|
||||
format: str = '',
|
||||
options: Optional[Options] = None,
|
||||
) -> Union[Mapping[str, Any], Iterator[Mapping[str, Any]]]:
|
||||
"""
|
||||
Returns `ChatResponse` if `stream` is `False`, otherwise returns a `ChatResponse` generator.
|
||||
"""
|
||||
|
||||
if not model:
|
||||
raise Exception('must provide a model')
|
||||
|
||||
@ -124,6 +132,9 @@ class Client(BaseClient):
|
||||
insecure: bool = False,
|
||||
stream: bool = False,
|
||||
) -> Union[Mapping[str, Any], Iterator[Mapping[str, Any]]]:
|
||||
"""
|
||||
Returns `ProgressResponse` if `stream` is `False`, otherwise returns a `ProgressResponse` generator.
|
||||
"""
|
||||
return self._request_stream(
|
||||
'POST',
|
||||
'/api/pull',
|
||||
@ -141,6 +152,9 @@ class Client(BaseClient):
|
||||
insecure: bool = False,
|
||||
stream: bool = False,
|
||||
) -> Union[Mapping[str, Any], Iterator[Mapping[str, Any]]]:
|
||||
"""
|
||||
Returns `ProgressResponse` if `stream` is `False`, otherwise returns a `ProgressResponse` generator.
|
||||
"""
|
||||
return self._request_stream(
|
||||
'POST',
|
||||
'/api/push',
|
||||
@ -159,6 +173,9 @@ class Client(BaseClient):
|
||||
modelfile: Optional[str] = None,
|
||||
stream: bool = False,
|
||||
) -> Union[Mapping[str, Any], Iterator[Mapping[str, Any]]]:
|
||||
"""
|
||||
Returns `ProgressResponse` if `stream` is `False`, otherwise returns a `ProgressResponse` generator.
|
||||
"""
|
||||
if (realpath := _as_path(path)) and realpath.exists():
|
||||
modelfile = self._parse_modelfile(realpath.read_text(), base=realpath.parent)
|
||||
elif modelfile:
|
||||
@ -267,13 +284,16 @@ class AsyncClient(BaseClient):
|
||||
prompt: str = '',
|
||||
system: str = '',
|
||||
template: str = '',
|
||||
context: Optional[List[int]] = None,
|
||||
context: Optional[Sequence[int]] = None,
|
||||
stream: bool = False,
|
||||
raw: bool = False,
|
||||
format: str = '',
|
||||
images: Optional[List[AnyStr]] = None,
|
||||
images: Optional[Sequence[AnyStr]] = None,
|
||||
options: Optional[Options] = None,
|
||||
) -> Union[Mapping[str, Any], AsyncIterator[Mapping[str, Any]]]:
|
||||
"""
|
||||
Returns `GenerateResponse` if `stream` is `False`, otherwise returns an asynchronous `GenerateResponse` generator.
|
||||
"""
|
||||
if not model:
|
||||
raise Exception('must provide a model')
|
||||
|
||||
@ -298,11 +318,14 @@ class AsyncClient(BaseClient):
|
||||
async def chat(
|
||||
self,
|
||||
model: str = '',
|
||||
messages: Optional[List[Message]] = None,
|
||||
messages: Optional[Sequence[Message]] = None,
|
||||
stream: bool = False,
|
||||
format: str = '',
|
||||
options: Optional[Options] = None,
|
||||
) -> Union[Mapping[str, Any], AsyncIterator[Mapping[str, Any]]]:
|
||||
"""
|
||||
Returns `ChatResponse` if `stream` is `False`, otherwise returns an asynchronous `ChatResponse` generator.
|
||||
"""
|
||||
if not model:
|
||||
raise Exception('must provide a model')
|
||||
|
||||
@ -335,6 +358,9 @@ class AsyncClient(BaseClient):
|
||||
insecure: bool = False,
|
||||
stream: bool = False,
|
||||
) -> Union[Mapping[str, Any], AsyncIterator[Mapping[str, Any]]]:
|
||||
"""
|
||||
Returns `ProgressResponse` if `stream` is `False`, otherwise returns a `ProgressResponse` generator.
|
||||
"""
|
||||
return await self._request_stream(
|
||||
'POST',
|
||||
'/api/pull',
|
||||
@ -352,6 +378,9 @@ class AsyncClient(BaseClient):
|
||||
insecure: bool = False,
|
||||
stream: bool = False,
|
||||
) -> Union[Mapping[str, Any], AsyncIterator[Mapping[str, Any]]]:
|
||||
"""
|
||||
Returns `ProgressResponse` if `stream` is `False`, otherwise returns a `ProgressResponse` generator.
|
||||
"""
|
||||
return await self._request_stream(
|
||||
'POST',
|
||||
'/api/push',
|
||||
@ -370,6 +399,9 @@ class AsyncClient(BaseClient):
|
||||
modelfile: Optional[str] = None,
|
||||
stream: bool = False,
|
||||
) -> Union[Mapping[str, Any], AsyncIterator[Mapping[str, Any]]]:
|
||||
"""
|
||||
Returns `ProgressResponse` if `stream` is `False`, otherwise returns a `ProgressResponse` generator.
|
||||
"""
|
||||
if (realpath := _as_path(path)) and realpath.exists():
|
||||
modelfile = await self._parse_modelfile(realpath.read_text(), base=realpath.parent)
|
||||
elif modelfile:
|
||||
|
||||
@ -1,4 +1,4 @@
|
||||
from typing import Any, TypedDict, List
|
||||
from typing import Any, TypedDict, Sequence
|
||||
|
||||
import sys
|
||||
|
||||
@ -8,10 +8,39 @@ else:
|
||||
from typing import NotRequired
|
||||
|
||||
|
||||
class BaseGenerateResponse(TypedDict):
|
||||
model: str
|
||||
created_at: str
|
||||
done: bool
|
||||
|
||||
total_duration: int
|
||||
load_duration: int
|
||||
prompt_eval_count: int
|
||||
prompt_eval_duration: int
|
||||
eval_count: int
|
||||
eval_duration: int
|
||||
|
||||
|
||||
class GenerateResponse(BaseGenerateResponse):
|
||||
response: str
|
||||
context: Sequence[int]
|
||||
|
||||
|
||||
class Message(TypedDict):
|
||||
role: str
|
||||
content: str
|
||||
images: NotRequired[List[Any]]
|
||||
images: NotRequired[Sequence[Any]]
|
||||
|
||||
|
||||
class ChatResponse(BaseGenerateResponse):
|
||||
message: Message
|
||||
|
||||
|
||||
class ProgressResponse(TypedDict):
|
||||
status: str
|
||||
completed: int
|
||||
total: int
|
||||
digest: str
|
||||
|
||||
|
||||
class Options(TypedDict, total=False):
|
||||
@ -50,4 +79,4 @@ class Options(TypedDict, total=False):
|
||||
mirostat_tau: float
|
||||
mirostat_eta: float
|
||||
penalize_newline: bool
|
||||
stop: List[str]
|
||||
stop: Sequence[str]
|
||||
|
||||
Loading…
Reference in New Issue
Block a user