Mirror of https://github.com/ollama/ollama-python.git (synced 2026-01-13 21:57:16 +08:00)
fix async client
This commit is contained in:
parent
a0388b2e32
commit
f5c8ee0a3e
@@ -546,24 +546,6 @@ class Client(BaseClient):
       stream=stream,
     )
 
-  def _parse_modelfile(self, modelfile: str, base: Optional[Path] = None) -> str:
-    base = Path.cwd() if base is None else base
-
-    out = io.StringIO()
-    for line in io.StringIO(modelfile):
-      command, _, args = line.partition(' ')
-      if command.upper() not in ['FROM', 'ADAPTER']:
-        print(line, end='', file=out)
-        continue
-
-      path = Path(args.strip()).expanduser()
-      path = path if path.is_absolute() else base / path
-      if path.exists():
-        args = f'@{self.create_blob(path)}\n'
-      print(command, args, end='', file=out)
-
-    return out.getvalue()
-
   def create_blob(self, path: Union[str, Path]) -> str:
     sha256sum = sha256()
     with open(path, 'rb') as r:
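For reference, the deleted helper above rewrote FROM/ADAPTER lines that pointed at local files into '@<digest>' blob references before the Modelfile was sent to the server. A minimal standalone sketch of that substitution, with resolve_digest as a hypothetical stand-in for Client.create_blob (which uploads the file and returns its sha256 digest):

import io
from pathlib import Path
from typing import Callable, Optional

def parse_modelfile(modelfile: str, resolve_digest: Callable[[Path], str], base: Optional[Path] = None) -> str:
  # Resolve relative paths against `base` (defaults to the current directory).
  base = Path.cwd() if base is None else base

  out = io.StringIO()
  for line in io.StringIO(modelfile):
    command, _, args = line.partition(' ')
    if command.upper() not in ['FROM', 'ADAPTER']:
      print(line, end='', file=out)
      continue

    # FROM/ADAPTER arguments naming an existing local file become blob references.
    path = Path(args.strip()).expanduser()
    path = path if path.is_absolute() else base / path
    if path.exists():
      args = f'@{resolve_digest(path)}\n'
    print(command, args, end='', file=out)

  return out.getvalue()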
@@ -996,31 +978,49 @@ class AsyncClient(BaseClient):
   async def create(
     self,
     model: str,
-    path: Optional[Union[str, PathLike]] = None,
-    modelfile: Optional[str] = None,
-    *,
     quantize: Optional[str] = None,
-    stream: Literal[False] = False,
+    from_: Optional[str] = None,
+    files: Optional[dict[str, str]] = None,
+    adapters: Optional[dict[str, str]] = None,
+    template: Optional[str] = None,
+    license: Optional[Union[str, list[str]]] = None,
+    system: Optional[str] = None,
+    parameters: Optional[Union[Mapping[str, Any], Options]] = None,
+    messages: Optional[Sequence[Union[Mapping[str, Any], Message]]] = None,
+    *,
+    stream: Literal[True] = True,
   ) -> ProgressResponse: ...
 
   @overload
   async def create(
     self,
     model: str,
-    path: Optional[Union[str, PathLike]] = None,
-    modelfile: Optional[str] = None,
-    *,
     quantize: Optional[str] = None,
+    from_: Optional[str] = None,
+    files: Optional[dict[str, str]] = None,
+    adapters: Optional[dict[str, str]] = None,
+    template: Optional[str] = None,
+    license: Optional[Union[str, list[str]]] = None,
+    system: Optional[str] = None,
+    parameters: Optional[Union[Mapping[str, Any], Options]] = None,
+    messages: Optional[Sequence[Union[Mapping[str, Any], Message]]] = None,
+    *,
     stream: Literal[True] = True,
   ) -> AsyncIterator[ProgressResponse]: ...
 
   async def create(
     self,
     model: str,
-    path: Optional[Union[str, PathLike]] = None,
-    modelfile: Optional[str] = None,
-    *,
     quantize: Optional[str] = None,
+    from_: Optional[str] = None,
+    files: Optional[dict[str, str]] = None,
+    adapters: Optional[dict[str, str]] = None,
+    template: Optional[str] = None,
+    license: Optional[Union[str, list[str]]] = None,
+    system: Optional[str] = None,
+    parameters: Optional[Union[Mapping[str, Any], Options]] = None,
+    messages: Optional[Sequence[Union[Mapping[str, Any], Message]]] = None,
+    *,
     stream: bool = False,
   ) -> Union[ProgressResponse, AsyncIterator[ProgressResponse]]:
     """
@@ -1028,12 +1028,6 @@ class AsyncClient(BaseClient):
 
     Returns `ProgressResponse` if `stream` is `False`, otherwise returns a `ProgressResponse` generator.
     """
-    if (realpath := _as_path(path)) and realpath.exists():
-      modelfile = await self._parse_modelfile(realpath.read_text(), base=realpath.parent)
-    elif modelfile:
-      modelfile = await self._parse_modelfile(modelfile)
-    else:
-      raise RequestError('must provide either path or modelfile')
 
     return await self._request(
       ProgressResponse,
@@ -1041,31 +1035,20 @@ class AsyncClient(BaseClient):
       '/api/create',
       json=CreateRequest(
         model=model,
-        modelfile=modelfile,
         stream=stream,
         quantize=quantize,
+        from_=from_,
+        files=files,
+        adapters=adapters,
+        license=license,
+        template=template,
+        system=system,
+        parameters=parameters,
+        messages=messages,
       ).model_dump(exclude_none=True),
       stream=stream,
     )
 
-  async def _parse_modelfile(self, modelfile: str, base: Optional[Path] = None) -> str:
-    base = Path.cwd() if base is None else base
-
-    out = io.StringIO()
-    for line in io.StringIO(modelfile):
-      command, _, args = line.partition(' ')
-      if command.upper() not in ['FROM', 'ADAPTER']:
-        print(line, end='', file=out)
-        continue
-
-      path = Path(args.strip()).expanduser()
-      path = path if path.is_absolute() else base / path
-      if path.exists():
-        args = f'@{await self.create_blob(path)}\n'
-      print(command, args, end='', file=out)
-
-    return out.getvalue()
-
   async def create_blob(self, path: Union[str, Path]) -> str:
     sha256sum = sha256()
     with open(path, 'rb') as r:
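The net effect for callers is that AsyncClient.create() no longer reads a local Modelfile at all: a model is created either from an existing model via from_, or from blobs previously uploaded with create_blob and referenced by digest via files/adapters. A minimal usage sketch, assuming a running Ollama server, an existing model named 'mymodel', and a local 'model.gguf' file (all of these names are placeholders):

import asyncio
from ollama import AsyncClient

async def main() -> None:
  client = AsyncClient()

  # Create 'dummy' as a quantized copy of an existing model.
  resp = await client.create('dummy', from_='mymodel', quantize='q4_k_m')
  print(resp['status'])

  # Or create it from an uploaded GGUF blob, referenced by its sha256 digest.
  digest = await client.create_blob('model.gguf')
  resp = await client.create('dummy', files={'model.gguf': digest})
  print(resp['status'])

asyncio.run(main())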
@@ -933,86 +933,13 @@ async def test_async_client_push_stream(httpserver: HTTPServer):
 
 
 @pytest.mark.asyncio
-async def test_async_client_create_path(httpserver: HTTPServer):
-  httpserver.expect_ordered_request(PrefixPattern('/api/blobs/'), method='POST').respond_with_response(Response(status=200))
-  httpserver.expect_ordered_request(
-    '/api/create',
-    method='POST',
-    json={
-      'model': 'dummy',
-      'modelfile': 'FROM @sha256:e3b0c44298fc1c149afbf4c8996fb92427ae41e4649b934ca495991b7852b855\n',
-      'stream': False,
-    },
-  ).respond_with_json({'status': 'success'})
-
-  client = AsyncClient(httpserver.url_for('/'))
-
-  with tempfile.NamedTemporaryFile() as modelfile:
-    with tempfile.NamedTemporaryFile() as blob:
-      modelfile.write(f'FROM {blob.name}'.encode('utf-8'))
-      modelfile.flush()
-
-      response = await client.create('dummy', path=modelfile.name)
-      assert response['status'] == 'success'
-
-
-@pytest.mark.asyncio
-async def test_async_client_create_path_relative(httpserver: HTTPServer):
-  httpserver.expect_ordered_request(PrefixPattern('/api/blobs/'), method='POST').respond_with_response(Response(status=200))
-  httpserver.expect_ordered_request(
-    '/api/create',
-    method='POST',
-    json={
-      'model': 'dummy',
-      'modelfile': 'FROM @sha256:e3b0c44298fc1c149afbf4c8996fb92427ae41e4649b934ca495991b7852b855\n',
-      'stream': False,
-    },
-  ).respond_with_json({'status': 'success'})
-
-  client = AsyncClient(httpserver.url_for('/'))
-
-  with tempfile.NamedTemporaryFile() as modelfile:
-    with tempfile.NamedTemporaryFile(dir=Path(modelfile.name).parent) as blob:
-      modelfile.write(f'FROM {Path(blob.name).name}'.encode('utf-8'))
-      modelfile.flush()
-
-      response = await client.create('dummy', path=modelfile.name)
-      assert response['status'] == 'success'
-
-
-@pytest.mark.asyncio
-async def test_async_client_create_path_user_home(httpserver: HTTPServer, userhomedir):
-  httpserver.expect_ordered_request(PrefixPattern('/api/blobs/'), method='POST').respond_with_response(Response(status=200))
-  httpserver.expect_ordered_request(
-    '/api/create',
-    method='POST',
-    json={
-      'model': 'dummy',
-      'modelfile': 'FROM @sha256:e3b0c44298fc1c149afbf4c8996fb92427ae41e4649b934ca495991b7852b855\n',
-      'stream': False,
-    },
-  ).respond_with_json({'status': 'success'})
-
-  client = AsyncClient(httpserver.url_for('/'))
-
-  with tempfile.NamedTemporaryFile() as modelfile:
-    with tempfile.NamedTemporaryFile(dir=userhomedir) as blob:
-      modelfile.write(f'FROM ~/{Path(blob.name).name}'.encode('utf-8'))
-      modelfile.flush()
-
-      response = await client.create('dummy', path=modelfile.name)
-      assert response['status'] == 'success'
-
-
-@pytest.mark.asyncio
-async def test_async_client_create_modelfile(httpserver: HTTPServer):
-  httpserver.expect_ordered_request(PrefixPattern('/api/blobs/'), method='POST').respond_with_response(Response(status=200))
+async def test_async_client_create_with_blob(httpserver: HTTPServer):
   httpserver.expect_ordered_request(
     '/api/create',
     method='POST',
    json={
       'model': 'dummy',
-      'modelfile': 'FROM @sha256:e3b0c44298fc1c149afbf4c8996fb92427ae41e4649b934ca495991b7852b855\n',
+      'files': {'test.gguf': 'sha256:e3b0c44298fc1c149afbf4c8996fb92427ae41e4649b934ca495991b7852b855'},
       'stream': False,
     },
   ).respond_with_json({'status': 'success'})
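A note on the digest that recurs in these expectations: the blob files created by the tests are empty, and the SHA-256 of zero bytes is always e3b0c44298fc1c149afbf4c8996fb92427ae41e4649b934ca495991b7852b855, so every expected request carries the same sha256 value. A one-liner to confirm:

from hashlib import sha256

# SHA-256 of an empty byte string, matching the empty NamedTemporaryFile blobs in these tests.
print(f'sha256:{sha256(b"").hexdigest()}')
# sha256:e3b0c44298fc1c149afbf4c8996fb92427ae41e4649b934ca495991b7852b855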
@@ -1020,30 +947,25 @@ async def test_async_client_create_modelfile(httpserver: HTTPServer):
   client = AsyncClient(httpserver.url_for('/'))
 
   with tempfile.NamedTemporaryFile() as blob:
-    response = await client.create('dummy', modelfile=f'FROM {blob.name}')
+    response = await client.create('dummy', files={'test.gguf': 'sha256:e3b0c44298fc1c149afbf4c8996fb92427ae41e4649b934ca495991b7852b855'})
    assert response['status'] == 'success'
 
 
 @pytest.mark.asyncio
-async def test_async_client_create_modelfile_roundtrip(httpserver: HTTPServer):
-  httpserver.expect_ordered_request(PrefixPattern('/api/blobs/'), method='POST').respond_with_response(Response(status=200))
+async def test_async_client_create_with_parameters_roundtrip(httpserver: HTTPServer):
   httpserver.expect_ordered_request(
     '/api/create',
     method='POST',
     json={
       'model': 'dummy',
-      'modelfile': '''FROM @sha256:e3b0c44298fc1c149afbf4c8996fb92427ae41e4649b934ca495991b7852b855
-TEMPLATE """[INST] <<SYS>>{{.System}}<</SYS>>
-{{.Prompt}} [/INST]"""
-SYSTEM """
-Use
-multiline
-strings.
-"""
-PARAMETER stop [INST]
-PARAMETER stop [/INST]
-PARAMETER stop <<SYS>>
-PARAMETER stop <</SYS>>''',
+      'quantize': 'q4_k_m',
+      'from': 'mymodel',
+      'adapters': {'someadapter.gguf': 'sha256:e3b0c44298fc1c149afbf4c8996fb92427ae41e4649b934ca495991b7852b855'},
+      'template': '[INST] <<SYS>>{{.System}}<</SYS>>\n{{.Prompt}} [/INST]',
+      'license': 'this is my license',
+      'system': '\nUse\nmultiline\nstrings.\n',
+      'parameters': {'stop': ['[INST]', '[/INST]', '<<SYS>>', '<</SYS>>'], 'pi': 3.14159},
+      'messages': [{'role': 'user', 'content': 'Hello there!'}, {'role': 'assistant', 'content': 'Hello there yourself!'}],
       'stream': False,
     },
   ).respond_with_json({'status': 'success'})
@@ -1053,22 +975,15 @@ PARAMETER stop <</SYS>>''',
   with tempfile.NamedTemporaryFile() as blob:
     response = await client.create(
       'dummy',
-      modelfile='\n'.join(
-        [
-          f'FROM {blob.name}',
-          'TEMPLATE """[INST] <<SYS>>{{.System}}<</SYS>>',
-          '{{.Prompt}} [/INST]"""',
-          'SYSTEM """',
-          'Use',
-          'multiline',
-          'strings.',
-          '"""',
-          'PARAMETER stop [INST]',
-          'PARAMETER stop [/INST]',
-          'PARAMETER stop <<SYS>>',
-          'PARAMETER stop <</SYS>>',
-        ]
-      ),
+      quantize='q4_k_m',
+      from_='mymodel',
+      adapters={'someadapter.gguf': 'sha256:e3b0c44298fc1c149afbf4c8996fb92427ae41e4649b934ca495991b7852b855'},
+      template='[INST] <<SYS>>{{.System}}<</SYS>>\n{{.Prompt}} [/INST]',
+      license='this is my license',
+      system='\nUse\nmultiline\nstrings.\n',
+      parameters={'stop': ['[INST]', '[/INST]', '<<SYS>>', '<</SYS>>'], 'pi': 3.14159},
+      messages=[{'role': 'user', 'content': 'Hello there!'}, {'role': 'assistant', 'content': 'Hello there yourself!'}],
+      stream=False,
     )
     assert response['status'] == 'success'
 
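Note how the roundtrip test passes from_='mymodel' in Python but expects 'from': 'mymodel' in the request body: 'from' is a reserved keyword, so the request model serializes the from_ field under the alias 'from'. An illustrative sketch of that pattern with pydantic (not the actual ollama._types.CreateRequest definition):

from typing import Optional
from pydantic import BaseModel, Field

class CreateRequestSketch(BaseModel):
  model: str
  # Accepted as `from_` in Python, emitted as 'from' on the wire.
  from_: Optional[str] = Field(None, serialization_alias='from')

print(CreateRequestSketch(model='dummy', from_='mymodel').model_dump(by_alias=True, exclude_none=True))
# {'model': 'dummy', 'from': 'mymodel'}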
@@ -1080,14 +995,14 @@ async def test_async_client_create_from_library(httpserver: HTTPServer):
     method='POST',
     json={
       'model': 'dummy',
-      'modelfile': 'FROM llama2',
+      'from': 'llama2',
       'stream': False,
     },
   ).respond_with_json({'status': 'success'})
 
   client = AsyncClient(httpserver.url_for('/'))
 
-  response = await client.create('dummy', modelfile='FROM llama2')
+  response = await client.create('dummy', from_='llama2')
   assert response['status'] == 'success'
 
 