From 967fd657f1fa5722dffed124ab6535ecf4a9e060 Mon Sep 17 00:00:00 2001
From: Parth Sareen
Date: Thu, 16 Jan 2025 13:55:17 -0800
Subject: [PATCH] client: improve error messaging on connection failure (#398)

* Improve error messaging on connection failure
---
 examples/create.py   |  8 ++++----
 ollama/_client.py    | 15 +++++++++++----
 ollama/_types.py     |  3 +++
 tests/test_client.py | 29 ++++++++++++++++++++++++++++-
 4 files changed, 46 insertions(+), 9 deletions(-)

diff --git a/examples/create.py b/examples/create.py
index dfc9094..14967a9 100755
--- a/examples/create.py
+++ b/examples/create.py
@@ -2,9 +2,9 @@ from ollama import Client
 
 client = Client()
 response = client.create(
-    model='my-assistant',
-    from_='llama3.2',
-    system="You are mario from Super Mario Bros.",
-    stream=False
+  model='my-assistant',
+  from_='llama3.2',
+  system='You are mario from Super Mario Bros.',
+  stream=False,
 )
 print(response.status)
diff --git a/ollama/_client.py b/ollama/_client.py
index c0fccd4..cbe43c9 100644
--- a/ollama/_client.py
+++ b/ollama/_client.py
@@ -106,17 +106,22 @@ class BaseClient:
     )
 
 
+CONNECTION_ERROR_MESSAGE = 'Failed to connect to Ollama. Please check that Ollama is downloaded, running and accessible. https://ollama.com/download'
+
+
 class Client(BaseClient):
   def __init__(self, host: Optional[str] = None, **kwargs) -> None:
     super().__init__(httpx.Client, host, **kwargs)
 
   def _request_raw(self, *args, **kwargs):
-    r = self._client.request(*args, **kwargs)
     try:
+      r = self._client.request(*args, **kwargs)
       r.raise_for_status()
+      return r
     except httpx.HTTPStatusError as e:
       raise ResponseError(e.response.text, e.response.status_code) from None
-    return r
+    except httpx.ConnectError:
+      raise ConnectionError(CONNECTION_ERROR_MESSAGE) from None
 
   @overload
   def _request(
@@ -613,12 +618,14 @@ class AsyncClient(BaseClient):
     super().__init__(httpx.AsyncClient, host, **kwargs)
 
   async def _request_raw(self, *args, **kwargs):
-    r = await self._client.request(*args, **kwargs)
     try:
+      r = await self._client.request(*args, **kwargs)
       r.raise_for_status()
+      return r
     except httpx.HTTPStatusError as e:
       raise ResponseError(e.response.text, e.response.status_code) from None
-    return r
+    except httpx.ConnectError:
+      raise ConnectionError(CONNECTION_ERROR_MESSAGE) from None
 
   @overload
   async def _request(
diff --git a/ollama/_types.py b/ollama/_types.py
index d70a4ac..710c536 100644
--- a/ollama/_types.py
+++ b/ollama/_types.py
@@ -535,3 +535,6 @@ class ResponseError(Exception):
 
     self.status_code = status_code
     'HTTP status code of the response.'
+
+  def __str__(self) -> str:
+    return f'{self.error} (status code: {self.status_code})'
diff --git a/tests/test_client.py b/tests/test_client.py
index 67e87b9..e74c936 100644
--- a/tests/test_client.py
+++ b/tests/test_client.py
@@ -9,7 +9,7 @@ from pydantic import BaseModel, ValidationError
 from pytest_httpserver import HTTPServer, URIPattern
 from werkzeug.wrappers import Request, Response
 
-from ollama._client import AsyncClient, Client, _copy_tools
+from ollama._client import CONNECTION_ERROR_MESSAGE, AsyncClient, Client, _copy_tools
 
 PNG_BASE64 = 'iVBORw0KGgoAAAANSUhEUgAAAAEAAAABCAIAAACQd1PeAAAADElEQVR4nGNgYGAAAAAEAAH2FzhVAAAAAElFTkSuQmCC'
 PNG_BYTES = base64.b64decode(PNG_BASE64)
@@ -1112,3 +1112,30 @@ def test_tool_validation():
   with pytest.raises(ValidationError):
     invalid_tool = {'type': 'invalid_type', 'function': {'name': 'test'}}
     list(_copy_tools([invalid_tool]))
+
+
+def test_client_connection_error():
+  client = Client('http://localhost:1234')
+
+  with pytest.raises(ConnectionError, match=CONNECTION_ERROR_MESSAGE):
+    client.chat('model', messages=[{'role': 'user', 'content': 'prompt'}])
+  with pytest.raises(ConnectionError, match=CONNECTION_ERROR_MESSAGE):
+    client.chat('model', messages=[{'role': 'user', 'content': 'prompt'}])
+  with pytest.raises(ConnectionError, match=CONNECTION_ERROR_MESSAGE):
+    client.generate('model', 'prompt')
+  with pytest.raises(ConnectionError, match=CONNECTION_ERROR_MESSAGE):
+    client.show('model')
+
+
+@pytest.mark.asyncio
+async def test_async_client_connection_error():
+  client = AsyncClient('http://localhost:1234')
+  with pytest.raises(ConnectionError) as exc_info:
+    await client.chat('model', messages=[{'role': 'user', 'content': 'prompt'}])
+  assert str(exc_info.value) == 'Failed to connect to Ollama. Please check that Ollama is downloaded, running and accessible. https://ollama.com/download'
+  with pytest.raises(ConnectionError) as exc_info:
+    await client.generate('model', 'prompt')
+  assert str(exc_info.value) == 'Failed to connect to Ollama. Please check that Ollama is downloaded, running and accessible. https://ollama.com/download'
+  with pytest.raises(ConnectionError) as exc_info:
+    await client.show('model')
+  assert str(exc_info.value) == 'Failed to connect to Ollama. Please check that Ollama is downloaded, running and accessible. https://ollama.com/download'
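
Usage sketch (illustrative, not part of the patch): with this change, a request made while the Ollama server is unreachable surfaces as a plain ConnectionError carrying CONNECTION_ERROR_MESSAGE instead of an httpx.ConnectError traceback, and a printed ResponseError now includes its HTTP status code. The host, model name, and example outputs below are placeholders.

from ollama import Client, ResponseError

# Placeholder host; the client default is http://localhost:11434.
client = Client('http://localhost:11434')

try:
  response = client.chat(model='llama3.2', messages=[{'role': 'user', 'content': 'Hello'}])
  print(response.message.content)
except ConnectionError as e:
  # Raised by the patched _request_raw when httpx cannot reach the server.
  print(e)  # Failed to connect to Ollama. Please check that Ollama is downloaded, running and accessible. https://ollama.com/download
except ResponseError as e:
  # The __str__ added in this patch appends the status code, e.g. '... (status code: 404)'.
  print(e)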