mirror of
https://github.com/ollama/ollama-python.git
synced 2026-01-13 13:47:17 +08:00
client: expose resource cleanup methods (#444)
This commit is contained in:
parent
115792583e
commit
d1d704050b
@ -1,3 +1,4 @@
|
||||
import contextlib
|
||||
import ipaddress
|
||||
import json
|
||||
import os
|
||||
@ -75,7 +76,7 @@ from ollama._types import (
|
||||
T = TypeVar('T')
|
||||
|
||||
|
||||
class BaseClient:
|
||||
class BaseClient(contextlib.AbstractContextManager, contextlib.AbstractAsyncContextManager):
|
||||
def __init__(
|
||||
self,
|
||||
client,
|
||||
@ -116,6 +117,12 @@ class BaseClient:
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
def __exit__(self, exc_type, exc_val, exc_tb):
|
||||
self.close()
|
||||
|
||||
async def __aexit__(self, exc_type, exc_val, exc_tb):
|
||||
await self.close()
|
||||
|
||||
|
||||
CONNECTION_ERROR_MESSAGE = 'Failed to connect to Ollama. Please check that Ollama is downloaded, running and accessible. https://ollama.com/download'
|
||||
|
||||
@ -124,6 +131,9 @@ class Client(BaseClient):
|
||||
def __init__(self, host: Optional[str] = None, **kwargs) -> None:
|
||||
super().__init__(httpx.Client, host, **kwargs)
|
||||
|
||||
def close(self):
|
||||
self._client.close()
|
||||
|
||||
def _request_raw(self, *args, **kwargs):
|
||||
try:
|
||||
r = self._client.request(*args, **kwargs)
|
||||
@ -702,6 +712,9 @@ class AsyncClient(BaseClient):
|
||||
def __init__(self, host: Optional[str] = None, **kwargs) -> None:
|
||||
super().__init__(httpx.AsyncClient, host, **kwargs)
|
||||
|
||||
async def close(self):
|
||||
await self._client.aclose()
|
||||
|
||||
async def _request_raw(self, *args, **kwargs):
|
||||
try:
|
||||
r = await self._client.request(*args, **kwargs)
|
||||
|
||||
@ -1347,3 +1347,33 @@ def test_client_explicit_bearer_header_overrides_env(monkeypatch: pytest.MonkeyP
|
||||
client = Client(headers={'Authorization': 'Bearer explicit-token'})
|
||||
assert client._client.headers['authorization'] == 'Bearer explicit-token'
|
||||
client.web_search('override check')
|
||||
|
||||
|
||||
def test_client_close():
|
||||
client = Client()
|
||||
client.close()
|
||||
assert client._client.is_closed
|
||||
|
||||
|
||||
@pytest.mark.anyio
|
||||
async def test_async_client_close():
|
||||
client = AsyncClient()
|
||||
await client.close()
|
||||
assert client._client.is_closed
|
||||
|
||||
|
||||
def test_client_context_manager():
|
||||
with Client() as client:
|
||||
assert isinstance(client, Client)
|
||||
assert not client._client.is_closed
|
||||
|
||||
assert client._client.is_closed
|
||||
|
||||
|
||||
@pytest.mark.anyio
|
||||
async def test_async_client_context_manager():
|
||||
async with AsyncClient() as client:
|
||||
assert isinstance(client, AsyncClient)
|
||||
assert not client._client.is_closed
|
||||
|
||||
assert client._client.is_closed
|
||||
|
||||
Loading…
Reference in New Issue
Block a user