From f5c8ee0a3e67cee0f8aeb12e1d2006cd67812b1b Mon Sep 17 00:00:00 2001 From: Patrick Devine Date: Mon, 13 Jan 2025 15:54:16 -0800 Subject: [PATCH] fix async client --- ollama/_client.py | 89 ++++++++++++----------------- tests/test_client.py | 131 ++++++++----------------------------------- 2 files changed, 59 insertions(+), 161 deletions(-) diff --git a/ollama/_client.py b/ollama/_client.py index 7958adb..fff16fc 100644 --- a/ollama/_client.py +++ b/ollama/_client.py @@ -546,24 +546,6 @@ class Client(BaseClient): stream=stream, ) - def _parse_modelfile(self, modelfile: str, base: Optional[Path] = None) -> str: - base = Path.cwd() if base is None else base - - out = io.StringIO() - for line in io.StringIO(modelfile): - command, _, args = line.partition(' ') - if command.upper() not in ['FROM', 'ADAPTER']: - print(line, end='', file=out) - continue - - path = Path(args.strip()).expanduser() - path = path if path.is_absolute() else base / path - if path.exists(): - args = f'@{self.create_blob(path)}\n' - print(command, args, end='', file=out) - - return out.getvalue() - def create_blob(self, path: Union[str, Path]) -> str: sha256sum = sha256() with open(path, 'rb') as r: @@ -996,31 +978,49 @@ class AsyncClient(BaseClient): async def create( self, model: str, - path: Optional[Union[str, PathLike]] = None, - modelfile: Optional[str] = None, - *, quantize: Optional[str] = None, - stream: Literal[False] = False, + from_: Optional[str] = None, + files: Optional[dict[str, str]] = None, + adapters: Optional[dict[str, str]] = None, + template: Optional[str] = None, + license: Optional[Union[str, list[str]]] = None, + system: Optional[str] = None, + parameters: Optional[Union[Mapping[str, Any], Options]] = None, + messages: Optional[Sequence[Union[Mapping[str, Any], Message]]] = None, + *, + stream: Literal[False] = False, ) -> ProgressResponse: ... 
@overload async def create( self, model: str, - path: Optional[Union[str, PathLike]] = None, - modelfile: Optional[str] = None, - *, quantize: Optional[str] = None, + from_: Optional[str] = None, + files: Optional[dict[str, str]] = None, + adapters: Optional[dict[str, str]] = None, + template: Optional[str] = None, + license: Optional[Union[str, list[str]]] = None, + system: Optional[str] = None, + parameters: Optional[Union[Mapping[str, Any], Options]] = None, + messages: Optional[Sequence[Union[Mapping[str, Any], Message]]] = None, + *, stream: Literal[True] = True, ) -> AsyncIterator[ProgressResponse]: ... async def create( self, model: str, - path: Optional[Union[str, PathLike]] = None, - modelfile: Optional[str] = None, - *, quantize: Optional[str] = None, + from_: Optional[str] = None, + files: Optional[dict[str, str]] = None, + adapters: Optional[dict[str, str]] = None, + template: Optional[str] = None, + license: Optional[Union[str, list[str]]] = None, + system: Optional[str] = None, + parameters: Optional[Union[Mapping[str, Any], Options]] = None, + messages: Optional[Sequence[Union[Mapping[str, Any], Message]]] = None, + *, stream: bool = False, ) -> Union[ProgressResponse, AsyncIterator[ProgressResponse]]: """ @@ -1028,12 +1028,6 @@ class AsyncClient(BaseClient): Returns `ProgressResponse` if `stream` is `False`, otherwise returns a `ProgressResponse` generator. 
""" - if (realpath := _as_path(path)) and realpath.exists(): - modelfile = await self._parse_modelfile(realpath.read_text(), base=realpath.parent) - elif modelfile: - modelfile = await self._parse_modelfile(modelfile) - else: - raise RequestError('must provide either path or modelfile') return await self._request( ProgressResponse, @@ -1041,31 +1035,20 @@ class AsyncClient(BaseClient): '/api/create', json=CreateRequest( model=model, - modelfile=modelfile, stream=stream, quantize=quantize, + from_=from_, + files=files, + adapters=adapters, + license=license, + template=template, + system=system, + parameters=parameters, + messages=messages, ).model_dump(exclude_none=True), stream=stream, ) - async def _parse_modelfile(self, modelfile: str, base: Optional[Path] = None) -> str: - base = Path.cwd() if base is None else base - - out = io.StringIO() - for line in io.StringIO(modelfile): - command, _, args = line.partition(' ') - if command.upper() not in ['FROM', 'ADAPTER']: - print(line, end='', file=out) - continue - - path = Path(args.strip()).expanduser() - path = path if path.is_absolute() else base / path - if path.exists(): - args = f'@{await self.create_blob(path)}\n' - print(command, args, end='', file=out) - - return out.getvalue() - async def create_blob(self, path: Union[str, Path]) -> str: sha256sum = sha256() with open(path, 'rb') as r: diff --git a/tests/test_client.py b/tests/test_client.py index 8085cf7..2a83eb0 100644 --- a/tests/test_client.py +++ b/tests/test_client.py @@ -933,86 +933,13 @@ async def test_async_client_push_stream(httpserver: HTTPServer): @pytest.mark.asyncio -async def test_async_client_create_path(httpserver: HTTPServer): - httpserver.expect_ordered_request(PrefixPattern('/api/blobs/'), method='POST').respond_with_response(Response(status=200)) +async def test_async_client_create_with_blob(httpserver: HTTPServer): httpserver.expect_ordered_request( '/api/create', method='POST', json={ 'model': 'dummy', - 'modelfile': 'FROM 
@sha256:e3b0c44298fc1c149afbf4c8996fb92427ae41e4649b934ca495991b7852b855\n', - 'stream': False, - }, - ).respond_with_json({'status': 'success'}) - - client = AsyncClient(httpserver.url_for('/')) - - with tempfile.NamedTemporaryFile() as modelfile: - with tempfile.NamedTemporaryFile() as blob: - modelfile.write(f'FROM {blob.name}'.encode('utf-8')) - modelfile.flush() - - response = await client.create('dummy', path=modelfile.name) - assert response['status'] == 'success' - - -@pytest.mark.asyncio -async def test_async_client_create_path_relative(httpserver: HTTPServer): - httpserver.expect_ordered_request(PrefixPattern('/api/blobs/'), method='POST').respond_with_response(Response(status=200)) - httpserver.expect_ordered_request( - '/api/create', - method='POST', - json={ - 'model': 'dummy', - 'modelfile': 'FROM @sha256:e3b0c44298fc1c149afbf4c8996fb92427ae41e4649b934ca495991b7852b855\n', - 'stream': False, - }, - ).respond_with_json({'status': 'success'}) - - client = AsyncClient(httpserver.url_for('/')) - - with tempfile.NamedTemporaryFile() as modelfile: - with tempfile.NamedTemporaryFile(dir=Path(modelfile.name).parent) as blob: - modelfile.write(f'FROM {Path(blob.name).name}'.encode('utf-8')) - modelfile.flush() - - response = await client.create('dummy', path=modelfile.name) - assert response['status'] == 'success' - - -@pytest.mark.asyncio -async def test_async_client_create_path_user_home(httpserver: HTTPServer, userhomedir): - httpserver.expect_ordered_request(PrefixPattern('/api/blobs/'), method='POST').respond_with_response(Response(status=200)) - httpserver.expect_ordered_request( - '/api/create', - method='POST', - json={ - 'model': 'dummy', - 'modelfile': 'FROM @sha256:e3b0c44298fc1c149afbf4c8996fb92427ae41e4649b934ca495991b7852b855\n', - 'stream': False, - }, - ).respond_with_json({'status': 'success'}) - - client = AsyncClient(httpserver.url_for('/')) - - with tempfile.NamedTemporaryFile() as modelfile: - with 
tempfile.NamedTemporaryFile(dir=userhomedir) as blob: - modelfile.write(f'FROM ~/{Path(blob.name).name}'.encode('utf-8')) - modelfile.flush() - - response = await client.create('dummy', path=modelfile.name) - assert response['status'] == 'success' - - -@pytest.mark.asyncio -async def test_async_client_create_modelfile(httpserver: HTTPServer): - httpserver.expect_ordered_request(PrefixPattern('/api/blobs/'), method='POST').respond_with_response(Response(status=200)) - httpserver.expect_ordered_request( - '/api/create', - method='POST', - json={ - 'model': 'dummy', - 'modelfile': 'FROM @sha256:e3b0c44298fc1c149afbf4c8996fb92427ae41e4649b934ca495991b7852b855\n', + 'files': {'test.gguf': 'sha256:e3b0c44298fc1c149afbf4c8996fb92427ae41e4649b934ca495991b7852b855'}, 'stream': False, }, ).respond_with_json({'status': 'success'}) @@ -1020,30 +947,25 @@ async def test_async_client_create_modelfile(httpserver: HTTPServer): client = AsyncClient(httpserver.url_for('/')) with tempfile.NamedTemporaryFile() as blob: - response = await client.create('dummy', modelfile=f'FROM {blob.name}') + response = await client.create('dummy', files={'test.gguf': 'sha256:e3b0c44298fc1c149afbf4c8996fb92427ae41e4649b934ca495991b7852b855'}) assert response['status'] == 'success' @pytest.mark.asyncio -async def test_async_client_create_modelfile_roundtrip(httpserver: HTTPServer): - httpserver.expect_ordered_request(PrefixPattern('/api/blobs/'), method='POST').respond_with_response(Response(status=200)) +async def test_async_client_create_with_parameters_roundtrip(httpserver: HTTPServer): httpserver.expect_ordered_request( '/api/create', method='POST', json={ 'model': 'dummy', - 'modelfile': '''FROM @sha256:e3b0c44298fc1c149afbf4c8996fb92427ae41e4649b934ca495991b7852b855 -TEMPLATE """[INST] <>{{.System}}<> -{{.Prompt}} [/INST]""" -SYSTEM """ -Use -multiline -strings. 
-""" -PARAMETER stop [INST] -PARAMETER stop [/INST] -PARAMETER stop <> -PARAMETER stop <>''', + 'quantize': 'q4_k_m', + 'from': 'mymodel', + 'adapters': {'someadapter.gguf': 'sha256:e3b0c44298fc1c149afbf4c8996fb92427ae41e4649b934ca495991b7852b855'}, + 'template': '[INST] <>{{.System}}<>\n{{.Prompt}} [/INST]', + 'license': 'this is my license', + 'system': '\nUse\nmultiline\nstrings.\n', + 'parameters': {'stop': ['[INST]', '[/INST]', '<>', '<>'], 'pi': 3.14159}, + 'messages': [{'role': 'user', 'content': 'Hello there!'}, {'role': 'assistant', 'content': 'Hello there yourself!'}], 'stream': False, }, ).respond_with_json({'status': 'success'}) @@ -1053,22 +975,15 @@ PARAMETER stop <>''', with tempfile.NamedTemporaryFile() as blob: response = await client.create( 'dummy', - modelfile='\n'.join( - [ - f'FROM {blob.name}', - 'TEMPLATE """[INST] <>{{.System}}<>', - '{{.Prompt}} [/INST]"""', - 'SYSTEM """', - 'Use', - 'multiline', - 'strings.', - '"""', - 'PARAMETER stop [INST]', - 'PARAMETER stop [/INST]', - 'PARAMETER stop <>', - 'PARAMETER stop <>', - ] - ), + quantize='q4_k_m', + from_='mymodel', + adapters={'someadapter.gguf': 'sha256:e3b0c44298fc1c149afbf4c8996fb92427ae41e4649b934ca495991b7852b855'}, + template='[INST] <>{{.System}}<>\n{{.Prompt}} [/INST]', + license='this is my license', + system='\nUse\nmultiline\nstrings.\n', + parameters={'stop': ['[INST]', '[/INST]', '<>', '<>'], 'pi': 3.14159}, + messages=[{'role': 'user', 'content': 'Hello there!'}, {'role': 'assistant', 'content': 'Hello there yourself!'}], + stream=False, ) assert response['status'] == 'success' @@ -1080,14 +995,14 @@ async def test_async_client_create_from_library(httpserver: HTTPServer): method='POST', json={ 'model': 'dummy', - 'modelfile': 'FROM llama2', + 'from': 'llama2', 'stream': False, }, ).respond_with_json({'status': 'success'}) client = AsyncClient(httpserver.url_for('/')) - response = await client.create('dummy', modelfile='FROM llama2') + response = await client.create('dummy', 
from_='llama2') assert response['status'] == 'success'