diff --git a/ollama/__init__.py b/ollama/__init__.py index d48df1b..c4ebe2b 100644 --- a/ollama/__init__.py +++ b/ollama/__init__.py @@ -2,8 +2,11 @@ from ollama._client import Client, AsyncClient from ollama._types import ( GenerateResponse, ChatResponse, + ProgressResponse, Message, Options, + RequestError, + ResponseError, ) __all__ = [ @@ -11,8 +14,11 @@ __all__ = [ 'AsyncClient', 'GenerateResponse', 'ChatResponse', + 'ProgressResponse', 'Message', 'Options', + 'RequestError', + 'ResponseError', 'generate', 'chat', 'pull', diff --git a/ollama/_client.py b/ollama/_client.py index 09b4a47..d3b2a4a 100644 --- a/ollama/_client.py +++ b/ollama/_client.py @@ -16,7 +16,7 @@ if sys.version_info < (3, 9): else: from collections.abc import Iterator, AsyncIterator -from ollama._types import Message, Options +from ollama._types import Message, Options, RequestError, ResponseError class BaseClient: @@ -42,15 +42,26 @@ class Client(BaseClient): def _request(self, method: str, url: str, **kwargs) -> httpx.Response: response = self._client.request(method, url, **kwargs) - response.raise_for_status() + + try: + response.raise_for_status() + except httpx.HTTPStatusError as e: + raise ResponseError(e.response.text, e.response.status_code) from None + return response def _stream(self, method: str, url: str, **kwargs) -> Iterator[Mapping[str, Any]]: with self._client.stream(method, url, **kwargs) as r: + try: + r.raise_for_status() + except httpx.HTTPStatusError as e: + e.response.read() + raise ResponseError(e.response.text, e.response.status_code) from None + for line in r.iter_lines(): partial = json.loads(line) if e := partial.get('error'): - raise Exception(e) + raise ResponseError(e) yield partial def _request_stream( @@ -75,11 +86,17 @@ class Client(BaseClient): options: Optional[Options] = None, ) -> Union[Mapping[str, Any], Iterator[Mapping[str, Any]]]: """ + Create a response using the requested model. + + Raises `RequestError` if a model is not provided. + + Raises `ResponseError` if the request could not be fulfilled. + Returns `GenerateResponse` if `stream` is `False`, otherwise returns a `GenerateResponse` generator. """ if not model: - raise Exception('must provide a model') + raise RequestError('must provide a model') return self._request_stream( 'POST', @@ -108,19 +125,25 @@ class Client(BaseClient): options: Optional[Options] = None, ) -> Union[Mapping[str, Any], Iterator[Mapping[str, Any]]]: """ + Create a chat response using the requested model. + + Raises `RequestError` if a model is not provided. + + Raises `ResponseError` if the request could not be fulfilled. + Returns `ChatResponse` if `stream` is `False`, otherwise returns a `ChatResponse` generator. """ if not model: - raise Exception('must provide a model') + raise RequestError('must provide a model') for message in messages or []: if not isinstance(message, dict): raise TypeError('messages must be a list of strings') if not (role := message.get('role')) or role not in ['system', 'user', 'assistant']: - raise Exception('messages must contain a role and it must be one of "system", "user", or "assistant"') + raise RequestError('messages must contain a role and it must be one of "system", "user", or "assistant"') if not message.get('content'): - raise Exception('messages must contain content') + raise RequestError('messages must contain content') if images := message.get('images'): message['images'] = [_encode_image(image) for image in images] @@ -144,6 +167,8 @@ class Client(BaseClient): stream: bool = False, ) -> Union[Mapping[str, Any], Iterator[Mapping[str, Any]]]: """ + Raises `ResponseError` if the request could not be fulfilled. + Returns `ProgressResponse` if `stream` is `False`, otherwise returns a `ProgressResponse` generator. """ return self._request_stream( @@ -164,6 +189,8 @@ class Client(BaseClient): stream: bool = False, ) -> Union[Mapping[str, Any], Iterator[Mapping[str, Any]]]: """ + Raises `ResponseError` if the request could not be fulfilled. + Returns `ProgressResponse` if `stream` is `False`, otherwise returns a `ProgressResponse` generator. """ return self._request_stream( @@ -185,6 +212,8 @@ class Client(BaseClient): stream: bool = False, ) -> Union[Mapping[str, Any], Iterator[Mapping[str, Any]]]: """ + Raises `ResponseError` if the request could not be fulfilled. + Returns `ProgressResponse` if `stream` is `False`, otherwise returns a `ProgressResponse` generator. """ if (realpath := _as_path(path)) and realpath.exists(): @@ -192,7 +221,7 @@ class Client(BaseClient): elif modelfile: modelfile = self._parse_modelfile(modelfile) else: - raise Exception('must provide either path or modelfile') + raise RequestError('must provide either path or modelfile') return self._request_stream( 'POST', @@ -233,8 +262,8 @@ class Client(BaseClient): try: self._request('HEAD', f'/api/blobs/{digest}') - except httpx.HTTPStatusError as e: - if e.response.status_code != 404: + except ResponseError as e: + if e.status_code != 404: raise with open(path, 'rb') as r: @@ -263,16 +292,27 @@ class AsyncClient(BaseClient): async def _request(self, method: str, url: str, **kwargs) -> httpx.Response: response = await self._client.request(method, url, **kwargs) - response.raise_for_status() + + try: + response.raise_for_status() + except httpx.HTTPStatusError as e: + raise ResponseError(e.response.text, e.response.status_code) from None + return response async def _stream(self, method: str, url: str, **kwargs) -> AsyncIterator[Mapping[str, Any]]: async def inner(): async with self._client.stream(method, url, **kwargs) as r: + try: + r.raise_for_status() + except httpx.HTTPStatusError as e: + e.response.read() + raise ResponseError(e.response.text, e.response.status_code) from None + async for line in r.aiter_lines(): partial = json.loads(line) if e := partial.get('error'): - raise Exception(e) + raise ResponseError(e) yield partial return inner() @@ -303,10 +343,16 @@ class AsyncClient(BaseClient): options: Optional[Options] = None, ) -> Union[Mapping[str, Any], AsyncIterator[Mapping[str, Any]]]: """ + Create a response using the requested model. + + Raises `RequestError` if a model is not provided. + + Raises `ResponseError` if the request could not be fulfilled. + Returns `GenerateResponse` if `stream` is `False`, otherwise returns an asynchronous `GenerateResponse` generator. """ if not model: - raise Exception('must provide a model') + raise RequestError('must provide a model') return await self._request_stream( 'POST', @@ -335,18 +381,24 @@ class AsyncClient(BaseClient): options: Optional[Options] = None, ) -> Union[Mapping[str, Any], AsyncIterator[Mapping[str, Any]]]: """ + Create a chat response using the requested model. + + Raises `RequestError` if a model is not provided. + + Raises `ResponseError` if the request could not be fulfilled. + Returns `ChatResponse` if `stream` is `False`, otherwise returns an asynchronous `ChatResponse` generator. """ if not model: - raise Exception('must provide a model') + raise RequestError('must provide a model') for message in messages or []: if not isinstance(message, dict): raise TypeError('messages must be a list of strings') if not (role := message.get('role')) or role not in ['system', 'user', 'assistant']: - raise Exception('messages must contain a role and it must be one of "system", "user", or "assistant"') + raise RequestError('messages must contain a role and it must be one of "system", "user", or "assistant"') if not message.get('content'): - raise Exception('messages must contain content') + raise RequestError('messages must contain content') if images := message.get('images'): message['images'] = [_encode_image(image) for image in images] @@ -370,6 +422,8 @@ class AsyncClient(BaseClient): stream: bool = False, ) -> Union[Mapping[str, Any], AsyncIterator[Mapping[str, Any]]]: """ + Raises `ResponseError` if the request could not be fulfilled. + Returns `ProgressResponse` if `stream` is `False`, otherwise returns a `ProgressResponse` generator. """ return await self._request_stream( @@ -390,6 +444,8 @@ class AsyncClient(BaseClient): stream: bool = False, ) -> Union[Mapping[str, Any], AsyncIterator[Mapping[str, Any]]]: """ + Raises `ResponseError` if the request could not be fulfilled. + Returns `ProgressResponse` if `stream` is `False`, otherwise returns a `ProgressResponse` generator. """ return await self._request_stream( @@ -411,6 +467,8 @@ class AsyncClient(BaseClient): stream: bool = False, ) -> Union[Mapping[str, Any], AsyncIterator[Mapping[str, Any]]]: """ + Raises `ResponseError` if the request could not be fulfilled. + Returns `ProgressResponse` if `stream` is `False`, otherwise returns a `ProgressResponse` generator. """ if (realpath := _as_path(path)) and realpath.exists(): @@ -418,7 +476,7 @@ class AsyncClient(BaseClient): elif modelfile: modelfile = await self._parse_modelfile(modelfile) else: - raise Exception('must provide either path or modelfile') + raise RequestError('must provide either path or modelfile') return await self._request_stream( 'POST', @@ -459,8 +517,8 @@ class AsyncClient(BaseClient): try: await self._request('HEAD', f'/api/blobs/{digest}') - except httpx.HTTPStatusError as e: - if e.response.status_code != 404: + except ResponseError as e: + if e.status_code != 404: raise async def upload_bytes(): @@ -498,7 +556,7 @@ def _encode_image(image) -> str: elif b := _as_bytesio(image): b64 = b64encode(b.read()) else: - raise Exception('images must be a list of bytes, path-like objects, or file-like objects') + raise RequestError('images must be a list of bytes, path-like objects, or file-like objects') return b64.decode('utf-8') diff --git a/ollama/_types.py b/ollama/_types.py index 059dd1e..226626e 100644 --- a/ollama/_types.py +++ b/ollama/_types.py @@ -80,3 +80,28 @@ class Options(TypedDict, total=False): mirostat_eta: float penalize_newline: bool stop: Sequence[str] + + +class RequestError(Exception): + """ + Common class for request errors. + """ + + def __init__(self, content: str): + super().__init__(content) + self.content = content + "Reason for the error." + + +class ResponseError(Exception): + """ + Common class for response errors. + """ + + def __init__(self, content: str, status_code: int = -1): + super().__init__(content) + self.content = content + "Reason for the error." + + self.status_code = status_code + "HTTP status code of the response."