Merge pull request #5 from jmorganca/mxyng

Mxyng
Michael Yang 2024-01-11 09:53:42 -08:00 committed by GitHub
commit efc7ea3e78
4 changed files with 319 additions and 46 deletions

README.md

@@ -68,3 +68,18 @@ async def chat():
asyncio.run(chat())
```
## Handling Errors
Errors are raised if requests return an error status or if an error is detected while streaming.
```python
model = 'does-not-yet-exist'
try:
  ollama.chat(model)
except ollama.ResponseError as e:
  print('Error:', e.content)
  if e.status_code == 404:
    ollama.pull(model)
```
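
Errors raised while streaming surface the same way: the generator raises `ollama.ResponseError` as soon as an error line appears in the stream. A minimal sketch of the streaming case (the model name is illustrative):

```python
model = 'does-not-yet-exist'
try:
  for part in ollama.generate(model, 'Why is the sky blue?', stream=True):
    print(part['response'], end='', flush=True)
except ollama.ResponseError as e:
  print('Error:', e.content)
```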

ollama/__init__.py

@@ -1,13 +1,27 @@
from ollama._client import Client, AsyncClient
from ollama._types import Message, Options
from ollama._types import (
GenerateResponse,
ChatResponse,
ProgressResponse,
Message,
Options,
RequestError,
ResponseError,
)
__all__ = [
'Client',
'AsyncClient',
'GenerateResponse',
'ChatResponse',
'ProgressResponse',
'Message',
'Options',
'RequestError',
'ResponseError',
'generate',
'chat',
'embeddings',
'pull',
'push',
'create',
@@ -21,6 +35,7 @@ _client = Client()
generate = _client.generate
chat = _client.chat
embeddings = _client.embeddings
pull = _client.pull
push = _client.push
create = _client.create
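
These module-level names are methods bound to a shared default `Client`, so the package can be used without constructing a client explicitly. A hedged sketch (the model name is illustrative):

```python
import ollama

# delegates to the module's default Client() instance
response = ollama.generate(model='llama2', prompt='Why is the sky blue?')
print(response['response'])
```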

ollama/_client.py

@@ -7,7 +7,7 @@ from pathlib import Path
from hashlib import sha256
from base64 import b64encode
from typing import Any, AnyStr, Union, Optional, List, Mapping
from typing import Any, AnyStr, Union, Optional, Sequence, Mapping, Literal
import sys
@@ -16,30 +16,62 @@ if sys.version_info < (3, 9):
else:
from collections.abc import Iterator, AsyncIterator
from ollama._types import Message, Options
from ollama._types import Message, Options, RequestError, ResponseError
class BaseClient:
def __init__(self, client, base_url: Optional[str] = None) -> None:
base_url = base_url or os.getenv('OLLAMA_HOST', 'http://127.0.0.1:11434')
self._client = client(base_url=base_url, follow_redirects=True, timeout=None)
def __init__(
self,
client,
base_url: Optional[str] = None,
follow_redirects: bool = True,
timeout: Any = None,
**kwargs,
) -> None:
"""
Creates an httpx client. Default parameters are the same as those defined in httpx
except for the following:
- `base_url`: http://127.0.0.1:11434
- `follow_redirects`: True
- `timeout`: None
`kwargs` are passed to the httpx client.
"""
self._client = client(
base_url=base_url or os.getenv('OLLAMA_HOST', 'http://127.0.0.1:11434'),
follow_redirects=follow_redirects,
timeout=timeout,
**kwargs,
)
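
Because `kwargs` are forwarded verbatim, any httpx client option can be set at construction time. A sketch with illustrative values:

```python
from ollama import Client

# headers and timeout are standard httpx.Client options,
# passed straight through by BaseClient
client = Client(
  'http://localhost:11434',
  headers={'x-some-header': 'some-value'},
  timeout=30.0,
)
```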
class Client(BaseClient):
def __init__(self, base_url: Optional[str] = None) -> None:
super().__init__(httpx.Client, base_url)
def __init__(self, base_url: Optional[str] = None, **kwargs) -> None:
super().__init__(httpx.Client, base_url, **kwargs)
def _request(self, method: str, url: str, **kwargs) -> httpx.Response:
response = self._client.request(method, url, **kwargs)
response.raise_for_status()
try:
response.raise_for_status()
except httpx.HTTPStatusError as e:
raise ResponseError(e.response.text, e.response.status_code) from None
return response
def _stream(self, method: str, url: str, **kwargs) -> Iterator[Mapping[str, Any]]:
with self._client.stream(method, url, **kwargs) as r:
try:
r.raise_for_status()
except httpx.HTTPStatusError as e:
e.response.read()
raise ResponseError(e.response.text, e.response.status_code) from None
for line in r.iter_lines():
partial = json.loads(line)
if e := partial.get('error'):
raise Exception(e)
raise ResponseError(e)
yield partial
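
An `error` field on any streamed JSON line now raises `ResponseError` instead of a bare `Exception`; since no HTTP status accompanies a mid-stream failure, `status_code` keeps its default of `-1`. A defensive-consumption sketch (model name illustrative):

```python
from ollama import Client, ResponseError

client = Client()
try:
  for part in client.generate('llama2', 'Hello', stream=True):
    print(part.get('response', ''), end='')
except ResponseError as e:
  # mid-stream errors carry no HTTP status, so status_code is -1
  print('stream failed:', e.content, e.status_code)
```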
def _request_stream(
@@ -56,15 +88,25 @@ class Client(BaseClient):
prompt: str = '',
system: str = '',
template: str = '',
context: Optional[List[int]] = None,
context: Optional[Sequence[int]] = None,
stream: bool = False,
raw: bool = False,
format: str = '',
images: Optional[List[AnyStr]] = None,
format: Literal['', 'json'] = '',
images: Optional[Sequence[AnyStr]] = None,
options: Optional[Options] = None,
) -> Union[Mapping[str, Any], Iterator[Mapping[str, Any]]]:
"""
Create a response using the requested model.
Raises `RequestError` if a model is not provided.
Raises `ResponseError` if the request could not be fulfilled.
Returns `GenerateResponse` if `stream` is `False`, otherwise returns a `GenerateResponse` generator.
"""
if not model:
raise Exception('must provide a model')
raise RequestError('must provide a model')
return self._request_stream(
'POST',
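
Per the docstring, `generate` returns a single mapping when `stream=False` and an iterator of mappings otherwise, and a missing model now fails fast with `RequestError` before any request is made. A usage sketch (model name illustrative):

```python
from ollama import Client, RequestError

client = Client()
try:
  result = client.generate(model='llama2', prompt='Why is the sky blue?')
except RequestError as e:
  print('bad request:', e.content)
else:
  # the returned context can seed a follow-up request
  result = client.generate(model='llama2', prompt='Elaborate.', context=result.get('context'))
```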
@@ -87,21 +129,31 @@ class Client(BaseClient):
def chat(
self,
model: str = '',
messages: Optional[List[Message]] = None,
messages: Optional[Sequence[Message]] = None,
stream: bool = False,
format: str = '',
format: Literal['', 'json'] = '',
options: Optional[Options] = None,
) -> Union[Mapping[str, Any], Iterator[Mapping[str, Any]]]:
"""
Create a chat response using the requested model.
Raises `RequestError` if a model is not provided.
Raises `ResponseError` if the request could not be fulfilled.
Returns `ChatResponse` if `stream` is `False`, otherwise returns a `ChatResponse` generator.
"""
if not model:
raise Exception('must provide a model')
raise RequestError('must provide a model')
for message in messages or []:
if not isinstance(message, dict):
raise TypeError('messages must be a list of strings')
raise TypeError('messages must be a list of Message or dict-like objects')
if not (role := message.get('role')) or role not in ['system', 'user', 'assistant']:
raise Exception('messages must contain a role and it must be one of "system", "user", or "assistant"')
raise RequestError('messages must contain a role and it must be one of "system", "user", or "assistant"')
if not message.get('content'):
raise Exception('messages must contain content')
raise RequestError('messages must contain content')
if images := message.get('images'):
message['images'] = [_encode_image(image) for image in images]
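
Validation happens before the request is sent, and any `images` entries are base64-encoded by `_encode_image`, which accepts paths, raw bytes, or file-like objects. A sketch (model and file path are illustrative):

```python
from ollama import Client

client = Client()
response = client.chat(
  model='llava',
  messages=[{
    'role': 'user',
    'content': 'What is in this picture?',
    'images': ['./picture.png'],  # path-like; raw bytes also work
  }],
)
print(response['message']['content'])
```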
@@ -118,12 +170,28 @@ class Client(BaseClient):
stream=stream,
)
def embeddings(self, model: str = '', prompt: str = '', options: Optional[Options] = None) -> Sequence[float]:
return self._request(
'POST',
'/api/embeddings',
json={
'model': model,
'prompt': prompt,
'options': options or {},
},
).json()
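
`embeddings` posts the model and prompt to `/api/embeddings` and returns the parsed JSON body. A sketch, assuming the server returns the vector under an `embedding` key (model name illustrative):

```python
from ollama import Client

client = Client()
result = client.embeddings(model='llama2', prompt='The sky is blue')
print(len(result['embedding']))  # dimensionality of the embedding vector
```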
def pull(
self,
model: str,
insecure: bool = False,
stream: bool = False,
) -> Union[Mapping[str, Any], Iterator[Mapping[str, Any]]]:
"""
Raises `ResponseError` if the request could not be fulfilled.
Returns `ProgressResponse` if `stream` is `False`, otherwise returns a `ProgressResponse` generator.
"""
return self._request_stream(
'POST',
'/api/pull',
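
Since `pull` goes through `_request_stream`, passing `stream=True` yields `ProgressResponse`-shaped mappings as the download proceeds. A sketch (model name illustrative):

```python
from ollama import Client

client = Client()
for progress in client.pull('llama2', stream=True):
  # 'completed' and 'total' only appear on download status lines
  if progress.get('completed') and progress.get('total'):
    print(f"{progress['status']}: {progress['completed']}/{progress['total']} bytes")
  else:
    print(progress.get('status'))
```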
@@ -141,6 +209,11 @@ class Client(BaseClient):
insecure: bool = False,
stream: bool = False,
) -> Union[Mapping[str, Any], Iterator[Mapping[str, Any]]]:
"""
Raises `ResponseError` if the request could not be fulfilled.
Returns `ProgressResponse` if `stream` is `False`, otherwise returns a `ProgressResponse` generator.
"""
return self._request_stream(
'POST',
'/api/push',
@@ -159,12 +232,17 @@ class Client(BaseClient):
modelfile: Optional[str] = None,
stream: bool = False,
) -> Union[Mapping[str, Any], Iterator[Mapping[str, Any]]]:
"""
Raises `ResponseError` if the request could not be fulfilled.
Returns `ProgressResponse` if `stream` is `False`, otherwise returns a `ProgressResponse` generator.
"""
if (realpath := _as_path(path)) and realpath.exists():
modelfile = self._parse_modelfile(realpath.read_text(), base=realpath.parent)
elif modelfile:
modelfile = self._parse_modelfile(modelfile)
else:
raise Exception('must provide either path or modelfile')
raise RequestError('must provide either path or modelfile')
return self._request_stream(
'POST',
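
`create` accepts either a path to a Modelfile or its contents as a string, and now raises `RequestError` when neither is supplied. A sketch using an inline modelfile (names are illustrative):

```python
from ollama import Client

client = Client()
modelfile = '''
FROM llama2
SYSTEM You are a terse assistant.
'''
for progress in client.create('terse-llama', modelfile=modelfile, stream=True):
  print(progress.get('status'))
```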
@@ -205,8 +283,8 @@ class Client(BaseClient):
try:
self._request('HEAD', f'/api/blobs/{digest}')
except httpx.HTTPStatusError as e:
if e.response.status_code != 404:
except ResponseError as e:
if e.status_code != 404:
raise
with open(path, 'rb') as r:
@@ -219,7 +297,7 @@ class Client(BaseClient):
return {'status': 'success' if response.status_code == 200 else 'error'}
def list(self) -> Mapping[str, Any]:
return self._request('GET', '/api/tags').json().get('models', [])
return self._request('GET', '/api/tags').json()
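
With this change, `list` returns the whole response body instead of just the `models` value, so callers index into it themselves. A sketch, assuming the usual `/api/tags` shape:

```python
from ollama import Client

client = Client()
for model in client.list().get('models', []):
  print(model.get('name'))
```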
def copy(self, source: str, target: str) -> Mapping[str, Any]:
response = self._request('POST', '/api/copy', json={'source': source, 'destination': target})
@@ -230,21 +308,32 @@
class AsyncClient(BaseClient):
def __init__(self, base_url: Optional[str] = None) -> None:
super().__init__(httpx.AsyncClient, base_url)
def __init__(self, base_url: Optional[str] = None, **kwargs) -> None:
super().__init__(httpx.AsyncClient, base_url, **kwargs)
async def _request(self, method: str, url: str, **kwargs) -> httpx.Response:
response = await self._client.request(method, url, **kwargs)
response.raise_for_status()
try:
response.raise_for_status()
except httpx.HTTPStatusError as e:
raise ResponseError(e.response.text, e.response.status_code) from None
return response
async def _stream(self, method: str, url: str, **kwargs) -> AsyncIterator[Mapping[str, Any]]:
async def inner():
async with self._client.stream(method, url, **kwargs) as r:
try:
r.raise_for_status()
except httpx.HTTPStatusError as e:
e.response.read()
raise ResponseError(e.response.text, e.response.status_code) from None
async for line in r.aiter_lines():
partial = json.loads(line)
if e := partial.get('error'):
raise Exception(e)
raise ResponseError(e)
yield partial
return inner()
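
`AsyncClient` mirrors the synchronous behavior: streamed errors raise `ResponseError` from inside the async generator. A sketch (model name illustrative):

```python
import asyncio
from ollama import AsyncClient, ResponseError

async def main():
  client = AsyncClient()
  try:
    async for part in await client.chat(
      model='llama2',
      messages=[{'role': 'user', 'content': 'Why is the sky blue?'}],
      stream=True,
    ):
      print(part['message']['content'], end='', flush=True)
  except ResponseError as e:
    print('Error:', e.content)

asyncio.run(main())
```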
@@ -267,15 +356,24 @@ class AsyncClient(BaseClient):
prompt: str = '',
system: str = '',
template: str = '',
context: Optional[List[int]] = None,
context: Optional[Sequence[int]] = None,
stream: bool = False,
raw: bool = False,
format: str = '',
images: Optional[List[AnyStr]] = None,
format: Literal['', 'json'] = '',
images: Optional[Sequence[AnyStr]] = None,
options: Optional[Options] = None,
) -> Union[Mapping[str, Any], AsyncIterator[Mapping[str, Any]]]:
"""
Create a response using the requested model.
Raises `RequestError` if a model is not provided.
Raises `ResponseError` if the request could not be fulfilled.
Returns `GenerateResponse` if `stream` is `False`, otherwise returns an asynchronous `GenerateResponse` generator.
"""
if not model:
raise Exception('must provide a model')
raise RequestError('must provide a model')
return await self._request_stream(
'POST',
@@ -298,21 +396,30 @@ class AsyncClient(BaseClient):
async def chat(
self,
model: str = '',
messages: Optional[List[Message]] = None,
messages: Optional[Sequence[Message]] = None,
stream: bool = False,
format: str = '',
format: Literal['', 'json'] = '',
options: Optional[Options] = None,
) -> Union[Mapping[str, Any], AsyncIterator[Mapping[str, Any]]]:
"""
Create a chat response using the requested model.
Raises `RequestError` if a model is not provided.
Raises `ResponseError` if the request could not be fulfilled.
Returns `ChatResponse` if `stream` is `False`, otherwise returns an asynchronous `ChatResponse` generator.
"""
if not model:
raise Exception('must provide a model')
raise RequestError('must provide a model')
for message in messages or []:
if not isinstance(message, dict):
raise TypeError('messages must be a list of strings')
if not (role := message.get('role')) or role not in ['system', 'user', 'assistant']:
raise Exception('messages must contain a role and it must be one of "system", "user", or "assistant"')
raise RequestError('messages must contain a role and it must be one of "system", "user", or "assistant"')
if not message.get('content'):
raise Exception('messages must contain content')
raise RequestError('messages must contain content')
if images := message.get('images'):
message['images'] = [_encode_image(image) for image in images]
@@ -329,12 +436,30 @@ class AsyncClient(BaseClient):
stream=stream,
)
async def embeddings(self, model: str = '', prompt: str = '', options: Optional[Options] = None) -> Sequence[float]:
response = await self._request(
'POST',
'/api/embeddings',
json={
'model': model,
'prompt': prompt,
'options': options or {},
},
)
return response.json()
async def pull(
self,
model: str,
insecure: bool = False,
stream: bool = False,
) -> Union[Mapping[str, Any], AsyncIterator[Mapping[str, Any]]]:
"""
Raises `ResponseError` if the request could not be fulfilled.
Returns `ProgressResponse` if `stream` is `False`, otherwise returns a `ProgressResponse` generator.
"""
return await self._request_stream(
'POST',
'/api/pull',
@@ -352,6 +477,11 @@ class AsyncClient(BaseClient):
insecure: bool = False,
stream: bool = False,
) -> Union[Mapping[str, Any], AsyncIterator[Mapping[str, Any]]]:
"""
Raises `ResponseError` if the request could not be fulfilled.
Returns `ProgressResponse` if `stream` is `False`, otherwise returns a `ProgressResponse` generator.
"""
return await self._request_stream(
'POST',
'/api/push',
@@ -370,12 +500,17 @@ class AsyncClient(BaseClient):
modelfile: Optional[str] = None,
stream: bool = False,
) -> Union[Mapping[str, Any], AsyncIterator[Mapping[str, Any]]]:
"""
Raises `ResponseError` if the request could not be fulfilled.
Returns `ProgressResponse` if `stream` is `False`, otherwise returns a `ProgressResponse` generator.
"""
if (realpath := _as_path(path)) and realpath.exists():
modelfile = await self._parse_modelfile(realpath.read_text(), base=realpath.parent)
elif modelfile:
modelfile = await self._parse_modelfile(modelfile)
else:
raise Exception('must provide either path or modelfile')
raise RequestError('must provide either path or modelfile')
return await self._request_stream(
'POST',
@@ -416,8 +551,8 @@ class AsyncClient(BaseClient):
try:
await self._request('HEAD', f'/api/blobs/{digest}')
except httpx.HTTPStatusError as e:
if e.response.status_code != 404:
except ResponseError as e:
if e.status_code != 404:
raise
async def upload_bytes():
@@ -438,7 +573,7 @@ class AsyncClient(BaseClient):
async def list(self) -> Mapping[str, Any]:
response = await self._request('GET', '/api/tags')
return response.json().get('models', [])
return response.json()
async def copy(self, source: str, target: str) -> Mapping[str, Any]:
response = await self._request('POST', '/api/copy', json={'source': source, 'destination': target})
@@ -455,7 +590,7 @@ def _encode_image(image) -> str:
elif b := _as_bytesio(image):
b64 = b64encode(b.read())
else:
raise Exception('images must be a list of bytes, path-like objects, or file-like objects')
raise RequestError('images must be a list of bytes, path-like objects, or file-like objects')
return b64.decode('utf-8')

ollama/_types.py

@@ -1,4 +1,5 @@
from typing import Any, TypedDict, List
import json
from typing import Any, TypedDict, Sequence, Literal
import sys
@@ -8,10 +9,85 @@ else:
from typing import NotRequired
class BaseGenerateResponse(TypedDict):
model: str
"Model used to generate response."
created_at: str
"Time when the request was created."
done: bool
"True if response is complete, otherwise False. Useful for streaming to detect the final response."
total_duration: int
"Total duration in nanoseconds."
load_duration: int
"Load duration in nanoseconds."
prompt_eval_count: int
"Number of tokens evaluated in the prompt."
prompt_eval_duration: int
"Duration of evaluating the prompt in nanoseconds."
eval_count: int
"Number of tokens evaluated in inference."
eval_duration: int
"Duration of evaluating inference in nanoseconds."
class GenerateResponse(BaseGenerateResponse):
"""
Response returned by generate requests.
"""
response: str
"Response content. When streaming, this contains a fragment of the response."
context: Sequence[int]
"Tokenized history up to the point of the response."
class Message(TypedDict):
role: str
"""
Chat message.
"""
role: Literal['user', 'assistant', 'system']
"Assumed role of the message. Response messages always has role 'assistant'."
content: str
images: NotRequired[List[Any]]
"Content of the message. Response messages contains message fragments when streaming."
images: NotRequired[Sequence[Any]]
"""
Optional list of image data for multimodal models.
Valid input types are:
- `str` or path-like object: path to image file
- `bytes` or bytes-like object: raw image data
Valid image formats depend on the model. See the model card for more information.
"""
class ChatResponse(BaseGenerateResponse):
"""
Response returned by chat requests.
"""
message: Message
"Response message."
class ProgressResponse(TypedDict):
status: str
completed: int
total: int
digest: str
class Options(TypedDict, total=False):
@@ -50,4 +126,36 @@ class Options(TypedDict, total=False):
mirostat_tau: float
mirostat_eta: float
penalize_newline: bool
stop: List[str]
stop: Sequence[str]
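
`Options` is declared with `total=False`, so any subset of fields may be passed, and `stop` now accepts any `Sequence[str]`. A sketch (model name and values illustrative):

```python
from ollama import Client, Options

options: Options = {
  'temperature': 0.2,
  'stop': ('\n\n',),  # a tuple satisfies Sequence[str] as well as a list
}
response = Client().generate(model='llama2', prompt='Say hi', options=options)
```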
class RequestError(Exception):
"""
Common class for request errors.
"""
def __init__(self, content: str):
super().__init__(content)
self.content = content
"Reason for the error."
class ResponseError(Exception):
"""
Common class for response errors.
"""
def __init__(self, content: str, status_code: int = -1):
try:
# try to parse content as JSON and extract 'error'
# fallback to raw content if JSON parsing fails
content = json.loads(content).get('error', content)
except json.JSONDecodeError:
...
super().__init__(content)
self.content = content
"Reason for the error."
self.status_code = status_code
"HTTP status code of the response."