request/response errors

This commit is contained in:
Michael Yang 2024-01-10 11:31:58 -08:00
parent 2804a03d82
commit 7601947a35
3 changed files with 109 additions and 20 deletions

View File

@ -2,8 +2,11 @@ from ollama._client import Client, AsyncClient
from ollama._types import (
GenerateResponse,
ChatResponse,
ProgressResponse,
Message,
Options,
RequestError,
ResponseError,
)
__all__ = [
@ -11,8 +14,11 @@ __all__ = [
'AsyncClient',
'GenerateResponse',
'ChatResponse',
'ProgressResponse',
'Message',
'Options',
'RequestError',
'ResponseError',
'generate',
'chat',
'pull',

View File

@ -16,7 +16,7 @@ if sys.version_info < (3, 9):
else:
from collections.abc import Iterator, AsyncIterator
from ollama._types import Message, Options
from ollama._types import Message, Options, RequestError, ResponseError
class BaseClient:
@ -42,15 +42,26 @@ class Client(BaseClient):
def _request(self, method: str, url: str, **kwargs) -> httpx.Response:
response = self._client.request(method, url, **kwargs)
response.raise_for_status()
try:
response.raise_for_status()
except httpx.HTTPStatusError as e:
raise ResponseError(e.response.text, e.response.status_code) from None
return response
def _stream(self, method: str, url: str, **kwargs) -> Iterator[Mapping[str, Any]]:
with self._client.stream(method, url, **kwargs) as r:
try:
r.raise_for_status()
except httpx.HTTPStatusError as e:
e.response.read()
raise ResponseError(e.response.text, e.response.status_code) from None
for line in r.iter_lines():
partial = json.loads(line)
if e := partial.get('error'):
raise Exception(e)
raise ResponseError(e)
yield partial
def _request_stream(
@ -75,11 +86,17 @@ class Client(BaseClient):
options: Optional[Options] = None,
) -> Union[Mapping[str, Any], Iterator[Mapping[str, Any]]]:
"""
Create a response using the requested model.
Raises `RequestError` if a model is not provided.
Raises `ResponseError` if the request could not be fulfilled.
Returns `GenerateResponse` if `stream` is `False`, otherwise returns a `GenerateResponse` generator.
"""
if not model:
raise Exception('must provide a model')
raise RequestError('must provide a model')
return self._request_stream(
'POST',
@ -108,19 +125,25 @@ class Client(BaseClient):
options: Optional[Options] = None,
) -> Union[Mapping[str, Any], Iterator[Mapping[str, Any]]]:
"""
Create a chat response using the requested model.
Raises `RequestError` if a model is not provided.
Raises `ResponseError` if the request could not be fulfilled.
Returns `ChatResponse` if `stream` is `False`, otherwise returns a `ChatResponse` generator.
"""
if not model:
raise Exception('must provide a model')
raise RequestError('must provide a model')
for message in messages or []:
if not isinstance(message, dict):
raise TypeError('messages must be a list of strings')
if not (role := message.get('role')) or role not in ['system', 'user', 'assistant']:
raise Exception('messages must contain a role and it must be one of "system", "user", or "assistant"')
raise RequestError('messages must contain a role and it must be one of "system", "user", or "assistant"')
if not message.get('content'):
raise Exception('messages must contain content')
raise RequestError('messages must contain content')
if images := message.get('images'):
message['images'] = [_encode_image(image) for image in images]
@ -144,6 +167,8 @@ class Client(BaseClient):
stream: bool = False,
) -> Union[Mapping[str, Any], Iterator[Mapping[str, Any]]]:
"""
Raises `ResponseError` if the request could not be fulfilled.
Returns `ProgressResponse` if `stream` is `False`, otherwise returns a `ProgressResponse` generator.
"""
return self._request_stream(
@ -164,6 +189,8 @@ class Client(BaseClient):
stream: bool = False,
) -> Union[Mapping[str, Any], Iterator[Mapping[str, Any]]]:
"""
Raises `ResponseError` if the request could not be fulfilled.
Returns `ProgressResponse` if `stream` is `False`, otherwise returns a `ProgressResponse` generator.
"""
return self._request_stream(
@ -185,6 +212,8 @@ class Client(BaseClient):
stream: bool = False,
) -> Union[Mapping[str, Any], Iterator[Mapping[str, Any]]]:
"""
Raises `ResponseError` if the request could not be fulfilled.
Returns `ProgressResponse` if `stream` is `False`, otherwise returns a `ProgressResponse` generator.
"""
if (realpath := _as_path(path)) and realpath.exists():
@ -192,7 +221,7 @@ class Client(BaseClient):
elif modelfile:
modelfile = self._parse_modelfile(modelfile)
else:
raise Exception('must provide either path or modelfile')
raise RequestError('must provide either path or modelfile')
return self._request_stream(
'POST',
@ -233,8 +262,8 @@ class Client(BaseClient):
try:
self._request('HEAD', f'/api/blobs/{digest}')
except httpx.HTTPStatusError as e:
if e.response.status_code != 404:
except ResponseError as e:
if e.status_code != 404:
raise
with open(path, 'rb') as r:
@ -263,16 +292,27 @@ class AsyncClient(BaseClient):
async def _request(self, method: str, url: str, **kwargs) -> httpx.Response:
response = await self._client.request(method, url, **kwargs)
response.raise_for_status()
try:
response.raise_for_status()
except httpx.HTTPStatusError as e:
raise ResponseError(e.response.text, e.response.status_code) from None
return response
async def _stream(self, method: str, url: str, **kwargs) -> AsyncIterator[Mapping[str, Any]]:
async def inner():
async with self._client.stream(method, url, **kwargs) as r:
try:
r.raise_for_status()
except httpx.HTTPStatusError as e:
e.response.read()
raise ResponseError(e.response.text, e.response.status_code) from None
async for line in r.aiter_lines():
partial = json.loads(line)
if e := partial.get('error'):
raise Exception(e)
raise ResponseError(e)
yield partial
return inner()
@ -303,10 +343,16 @@ class AsyncClient(BaseClient):
options: Optional[Options] = None,
) -> Union[Mapping[str, Any], AsyncIterator[Mapping[str, Any]]]:
"""
Create a response using the requested model.
Raises `RequestError` if a model is not provided.
Raises `ResponseError` if the request could not be fulfilled.
Returns `GenerateResponse` if `stream` is `False`, otherwise returns an asynchronous `GenerateResponse` generator.
"""
if not model:
raise Exception('must provide a model')
raise RequestError('must provide a model')
return await self._request_stream(
'POST',
@ -335,18 +381,24 @@ class AsyncClient(BaseClient):
options: Optional[Options] = None,
) -> Union[Mapping[str, Any], AsyncIterator[Mapping[str, Any]]]:
"""
Create a chat response using the requested model.
Raises `RequestError` if a model is not provided.
Raises `ResponseError` if the request could not be fulfilled.
Returns `ChatResponse` if `stream` is `False`, otherwise returns an asynchronous `ChatResponse` generator.
"""
if not model:
raise Exception('must provide a model')
raise RequestError('must provide a model')
for message in messages or []:
if not isinstance(message, dict):
raise TypeError('messages must be a list of strings')
if not (role := message.get('role')) or role not in ['system', 'user', 'assistant']:
raise Exception('messages must contain a role and it must be one of "system", "user", or "assistant"')
raise RequestError('messages must contain a role and it must be one of "system", "user", or "assistant"')
if not message.get('content'):
raise Exception('messages must contain content')
raise RequestError('messages must contain content')
if images := message.get('images'):
message['images'] = [_encode_image(image) for image in images]
@ -370,6 +422,8 @@ class AsyncClient(BaseClient):
stream: bool = False,
) -> Union[Mapping[str, Any], AsyncIterator[Mapping[str, Any]]]:
"""
Raises `ResponseError` if the request could not be fulfilled.
Returns `ProgressResponse` if `stream` is `False`, otherwise returns a `ProgressResponse` generator.
"""
return await self._request_stream(
@ -390,6 +444,8 @@ class AsyncClient(BaseClient):
stream: bool = False,
) -> Union[Mapping[str, Any], AsyncIterator[Mapping[str, Any]]]:
"""
Raises `ResponseError` if the request could not be fulfilled.
Returns `ProgressResponse` if `stream` is `False`, otherwise returns a `ProgressResponse` generator.
"""
return await self._request_stream(
@ -411,6 +467,8 @@ class AsyncClient(BaseClient):
stream: bool = False,
) -> Union[Mapping[str, Any], AsyncIterator[Mapping[str, Any]]]:
"""
Raises `ResponseError` if the request could not be fulfilled.
Returns `ProgressResponse` if `stream` is `False`, otherwise returns a `ProgressResponse` generator.
"""
if (realpath := _as_path(path)) and realpath.exists():
@ -418,7 +476,7 @@ class AsyncClient(BaseClient):
elif modelfile:
modelfile = await self._parse_modelfile(modelfile)
else:
raise Exception('must provide either path or modelfile')
raise RequestError('must provide either path or modelfile')
return await self._request_stream(
'POST',
@ -459,8 +517,8 @@ class AsyncClient(BaseClient):
try:
await self._request('HEAD', f'/api/blobs/{digest}')
except httpx.HTTPStatusError as e:
if e.response.status_code != 404:
except ResponseError as e:
if e.status_code != 404:
raise
async def upload_bytes():
@ -498,7 +556,7 @@ def _encode_image(image) -> str:
elif b := _as_bytesio(image):
b64 = b64encode(b.read())
else:
raise Exception('images must be a list of bytes, path-like objects, or file-like objects')
raise RequestError('images must be a list of bytes, path-like objects, or file-like objects')
return b64.decode('utf-8')

View File

@ -80,3 +80,28 @@ class Options(TypedDict, total=False):
mirostat_eta: float
penalize_newline: bool
stop: Sequence[str]
class RequestError(Exception):
"""
Common class for request errors.
"""
def __init__(self, content: str):
super().__init__(content)
self.content = content
"Reason for the error."
class ResponseError(Exception):
"""
Common class for response errors.
"""
def __init__(self, content: str, status_code: int = -1):
super().__init__(content)
self.content = content
"Reason for the error."
self.status_code = status_code
"HTTP status code of the response."