diff --git a/ollama/_client.py b/ollama/_client.py index 5908dfc..c1eaf85 100644 --- a/ollama/_client.py +++ b/ollama/_client.py @@ -259,13 +259,16 @@ class Client(BaseClient): out = io.StringIO() for line in io.StringIO(modelfile): command, _, args = line.partition(' ') - if command.upper() in ['FROM', 'ADAPTER']: - path = Path(args.strip()).expanduser() - path = path if path.is_absolute() else base / path - if path.exists(): - args = f'@{self._create_blob(path)}' + if command.upper() not in ['FROM', 'ADAPTER']: + print(line, end='', file=out) + continue + + path = Path(args.strip()).expanduser() + path = path if path.is_absolute() else base / path + if path.exists(): + args = f'@{self._create_blob(path)}\n' + print(command, args, end='', file=out) - print(command, args, file=out) return out.getvalue() def _create_blob(self, path: Union[str, Path]) -> str: @@ -527,13 +530,16 @@ class AsyncClient(BaseClient): out = io.StringIO() for line in io.StringIO(modelfile): command, _, args = line.partition(' ') - if command.upper() in ['FROM', 'ADAPTER']: - path = Path(args).expanduser() - path = path if path.is_absolute() else base / path - if path.exists(): - args = f'@{await self._create_blob(path)}' + if command.upper() not in ['FROM', 'ADAPTER']: + print(line, end='', file=out) + continue + + path = Path(args.strip()).expanduser() + path = path if path.is_absolute() else base / path + if path.exists(): + args = f'@{await self._create_blob(path)}\n' + print(command, args, end='', file=out) - print(command, args, file=out) return out.getvalue() async def _create_blob(self, path: Union[str, Path]) -> str: diff --git a/tests/test_client.py b/tests/test_client.py index 6afbe70..08aa789 100644 --- a/tests/test_client.py +++ b/tests/test_client.py @@ -416,13 +416,61 @@ def test_client_create_modelfile(httpserver: HTTPServer): assert isinstance(response, dict) +def test_client_create_modelfile_roundtrip(httpserver: HTTPServer): + httpserver.expect_ordered_request(PrefixPattern('/api/blobs/'), method='HEAD').respond_with_response(Response(status=200)) + httpserver.expect_ordered_request( + '/api/create', + method='POST', + json={ + 'name': 'dummy', + 'modelfile': '''FROM @sha256:e3b0c44298fc1c149afbf4c8996fb92427ae41e4649b934ca495991b7852b855 +TEMPLATE """[INST] <>{{.System}}<> +{{.Prompt}} [/INST]""" +SYSTEM """ +Use +multiline +strings. +""" +PARAMETER stop [INST] +PARAMETER stop [/INST] +PARAMETER stop <> +PARAMETER stop <>''', + 'stream': False, + }, + ).respond_with_json({}) + + client = Client(httpserver.url_for('/')) + + with tempfile.NamedTemporaryFile() as blob: + response = client.create( + 'dummy', + modelfile='\n'.join( + [ + f'FROM {blob.name}', + 'TEMPLATE """[INST] <>{{.System}}<>', + '{{.Prompt}} [/INST]"""', + 'SYSTEM """', + 'Use', + 'multiline', + 'strings.', + '"""', + 'PARAMETER stop [INST]', + 'PARAMETER stop [/INST]', + 'PARAMETER stop <>', + 'PARAMETER stop <>', + ] + ), + ) + assert isinstance(response, dict) + + def test_client_create_from_library(httpserver: HTTPServer): httpserver.expect_ordered_request( '/api/create', method='POST', json={ 'name': 'dummy', - 'modelfile': 'FROM llama2\n', + 'modelfile': 'FROM llama2', 'stream': False, }, ).respond_with_json({}) @@ -820,6 +868,55 @@ async def test_async_client_create_modelfile(httpserver: HTTPServer): assert isinstance(response, dict) +@pytest.mark.asyncio +async def test_async_client_create_modelfile_roundtrip(httpserver: HTTPServer): + httpserver.expect_ordered_request(PrefixPattern('/api/blobs/'), method='HEAD').respond_with_response(Response(status=200)) + httpserver.expect_ordered_request( + '/api/create', + method='POST', + json={ + 'name': 'dummy', + 'modelfile': '''FROM @sha256:e3b0c44298fc1c149afbf4c8996fb92427ae41e4649b934ca495991b7852b855 +TEMPLATE """[INST] <>{{.System}}<> +{{.Prompt}} [/INST]""" +SYSTEM """ +Use +multiline +strings. +""" +PARAMETER stop [INST] +PARAMETER stop [/INST] +PARAMETER stop <> +PARAMETER stop <>''', + 'stream': False, + }, + ).respond_with_json({}) + + client = AsyncClient(httpserver.url_for('/')) + + with tempfile.NamedTemporaryFile() as blob: + response = await client.create( + 'dummy', + modelfile='\n'.join( + [ + f'FROM {blob.name}', + 'TEMPLATE """[INST] <>{{.System}}<>', + '{{.Prompt}} [/INST]"""', + 'SYSTEM """', + 'Use', + 'multiline', + 'strings.', + '"""', + 'PARAMETER stop [INST]', + 'PARAMETER stop [/INST]', + 'PARAMETER stop <>', + 'PARAMETER stop <>', + ] + ), + ) + assert isinstance(response, dict) + + @pytest.mark.asyncio async def test_async_client_create_from_library(httpserver: HTTPServer): httpserver.expect_ordered_request( @@ -827,7 +924,7 @@ async def test_async_client_create_from_library(httpserver: HTTPServer): method='POST', json={ 'name': 'dummy', - 'modelfile': 'FROM llama2\n', + 'modelfile': 'FROM llama2', 'stream': False, }, ).respond_with_json({})