mirror of
https://github.com/ollama/ollama-python.git
synced 2026-01-13 21:57:16 +08:00
add unit tests
This commit is contained in:
parent
4f9fb88137
commit
4dec73e8be
@ -526,9 +526,6 @@ class Client(BaseClient):
|
|||||||
|
|
||||||
Returns `ProgressResponse` if `stream` is `False`, otherwise returns a `ProgressResponse` generator.
|
Returns `ProgressResponse` if `stream` is `False`, otherwise returns a `ProgressResponse` generator.
|
||||||
"""
|
"""
|
||||||
#if from_ == None and files == None:
|
|
||||||
# raise RequestError('neither ''from'' or ''files'' was specified')
|
|
||||||
|
|
||||||
return self._request(
|
return self._request(
|
||||||
ProgressResponse,
|
ProgressResponse,
|
||||||
'POST',
|
'POST',
|
||||||
@ -541,6 +538,7 @@ class Client(BaseClient):
|
|||||||
files=files,
|
files=files,
|
||||||
adapters=adapters,
|
adapters=adapters,
|
||||||
license=license,
|
license=license,
|
||||||
|
template=template,
|
||||||
system=system,
|
system=system,
|
||||||
parameters=parameters,
|
parameters=parameters,
|
||||||
messages=messages,
|
messages=messages,
|
||||||
|
|||||||
@ -536,51 +536,6 @@ def test_client_push_stream(httpserver: HTTPServer):
|
|||||||
assert part['status'] == next(it)
|
assert part['status'] == next(it)
|
||||||
|
|
||||||
|
|
||||||
def test_client_create_path(httpserver: HTTPServer):
|
|
||||||
httpserver.expect_ordered_request(PrefixPattern('/api/blobs/'), method='POST').respond_with_response(Response(status=200))
|
|
||||||
httpserver.expect_ordered_request(
|
|
||||||
'/api/create',
|
|
||||||
method='POST',
|
|
||||||
json={
|
|
||||||
'model': 'dummy',
|
|
||||||
'modelfile': 'FROM @sha256:e3b0c44298fc1c149afbf4c8996fb92427ae41e4649b934ca495991b7852b855\n',
|
|
||||||
'stream': False,
|
|
||||||
},
|
|
||||||
).respond_with_json({'status': 'success'})
|
|
||||||
|
|
||||||
client = Client(httpserver.url_for('/'))
|
|
||||||
|
|
||||||
with tempfile.NamedTemporaryFile() as modelfile:
|
|
||||||
with tempfile.NamedTemporaryFile() as blob:
|
|
||||||
modelfile.write(f'FROM {blob.name}'.encode('utf-8'))
|
|
||||||
modelfile.flush()
|
|
||||||
|
|
||||||
response = client.create('dummy', path=modelfile.name)
|
|
||||||
assert response['status'] == 'success'
|
|
||||||
|
|
||||||
|
|
||||||
def test_client_create_path_relative(httpserver: HTTPServer):
|
|
||||||
httpserver.expect_ordered_request(PrefixPattern('/api/blobs/'), method='POST').respond_with_response(Response(status=200))
|
|
||||||
httpserver.expect_ordered_request(
|
|
||||||
'/api/create',
|
|
||||||
method='POST',
|
|
||||||
json={
|
|
||||||
'model': 'dummy',
|
|
||||||
'modelfile': 'FROM @sha256:e3b0c44298fc1c149afbf4c8996fb92427ae41e4649b934ca495991b7852b855\n',
|
|
||||||
'stream': False,
|
|
||||||
},
|
|
||||||
).respond_with_json({'status': 'success'})
|
|
||||||
|
|
||||||
client = Client(httpserver.url_for('/'))
|
|
||||||
|
|
||||||
with tempfile.NamedTemporaryFile() as modelfile:
|
|
||||||
with tempfile.NamedTemporaryFile(dir=Path(modelfile.name).parent) as blob:
|
|
||||||
modelfile.write(f'FROM {Path(blob.name).name}'.encode('utf-8'))
|
|
||||||
modelfile.flush()
|
|
||||||
|
|
||||||
response = client.create('dummy', path=modelfile.name)
|
|
||||||
assert response['status'] == 'success'
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture
|
@pytest.fixture
|
||||||
def userhomedir():
|
def userhomedir():
|
||||||
@ -591,37 +546,13 @@ def userhomedir():
|
|||||||
os.environ['HOME'] = home
|
os.environ['HOME'] = home
|
||||||
|
|
||||||
|
|
||||||
def test_client_create_path_user_home(httpserver: HTTPServer, userhomedir):
|
def test_client_create_with_blob(httpserver: HTTPServer):
|
||||||
httpserver.expect_ordered_request(PrefixPattern('/api/blobs/'), method='POST').respond_with_response(Response(status=200))
|
|
||||||
httpserver.expect_ordered_request(
|
httpserver.expect_ordered_request(
|
||||||
'/api/create',
|
'/api/create',
|
||||||
method='POST',
|
method='POST',
|
||||||
json={
|
json={
|
||||||
'model': 'dummy',
|
'model': 'dummy',
|
||||||
'modelfile': 'FROM @sha256:e3b0c44298fc1c149afbf4c8996fb92427ae41e4649b934ca495991b7852b855\n',
|
'files': {'test.gguf': 'sha256:e3b0c44298fc1c149afbf4c8996fb92427ae41e4649b934ca495991b7852b855'},
|
||||||
'stream': False,
|
|
||||||
},
|
|
||||||
).respond_with_json({'status': 'success'})
|
|
||||||
|
|
||||||
client = Client(httpserver.url_for('/'))
|
|
||||||
|
|
||||||
with tempfile.NamedTemporaryFile() as modelfile:
|
|
||||||
with tempfile.NamedTemporaryFile(dir=userhomedir) as blob:
|
|
||||||
modelfile.write(f'FROM ~/{Path(blob.name).name}'.encode('utf-8'))
|
|
||||||
modelfile.flush()
|
|
||||||
|
|
||||||
response = client.create('dummy', path=modelfile.name)
|
|
||||||
assert response['status'] == 'success'
|
|
||||||
|
|
||||||
|
|
||||||
def test_client_create_modelfile(httpserver: HTTPServer):
|
|
||||||
httpserver.expect_ordered_request(PrefixPattern('/api/blobs/'), method='POST').respond_with_response(Response(status=200))
|
|
||||||
httpserver.expect_ordered_request(
|
|
||||||
'/api/create',
|
|
||||||
method='POST',
|
|
||||||
json={
|
|
||||||
'model': 'dummy',
|
|
||||||
'modelfile': 'FROM @sha256:e3b0c44298fc1c149afbf4c8996fb92427ae41e4649b934ca495991b7852b855\n',
|
|
||||||
'stream': False,
|
'stream': False,
|
||||||
},
|
},
|
||||||
).respond_with_json({'status': 'success'})
|
).respond_with_json({'status': 'success'})
|
||||||
@ -629,29 +560,24 @@ def test_client_create_modelfile(httpserver: HTTPServer):
|
|||||||
client = Client(httpserver.url_for('/'))
|
client = Client(httpserver.url_for('/'))
|
||||||
|
|
||||||
with tempfile.NamedTemporaryFile() as blob:
|
with tempfile.NamedTemporaryFile() as blob:
|
||||||
response = client.create('dummy', modelfile=f'FROM {blob.name}')
|
response = client.create('dummy', files={'test.gguf': 'sha256:e3b0c44298fc1c149afbf4c8996fb92427ae41e4649b934ca495991b7852b855'})
|
||||||
assert response['status'] == 'success'
|
assert response['status'] == 'success'
|
||||||
|
|
||||||
|
|
||||||
def test_client_create_modelfile_roundtrip(httpserver: HTTPServer):
|
def test_client_create_with_parameters_roundtrip(httpserver: HTTPServer):
|
||||||
httpserver.expect_ordered_request(PrefixPattern('/api/blobs/'), method='POST').respond_with_response(Response(status=200))
|
|
||||||
httpserver.expect_ordered_request(
|
httpserver.expect_ordered_request(
|
||||||
'/api/create',
|
'/api/create',
|
||||||
method='POST',
|
method='POST',
|
||||||
json={
|
json={
|
||||||
'model': 'dummy',
|
'model': 'dummy',
|
||||||
'modelfile': '''FROM @sha256:e3b0c44298fc1c149afbf4c8996fb92427ae41e4649b934ca495991b7852b855
|
'quantize': 'q4_k_m',
|
||||||
TEMPLATE """[INST] <<SYS>>{{.System}}<</SYS>>
|
'from': 'mymodel',
|
||||||
{{.Prompt}} [/INST]"""
|
'adapters': {'someadapter.gguf': 'sha256:e3b0c44298fc1c149afbf4c8996fb92427ae41e4649b934ca495991b7852b855'},
|
||||||
SYSTEM """
|
'template': '[INST] <<SYS>>{{.System}}<</SYS>>\n{{.Prompt}} [/INST]',
|
||||||
Use
|
'license': 'this is my license',
|
||||||
multiline
|
'system': '\nUse\nmultiline\nstrings.\n',
|
||||||
strings.
|
'parameters': {'stop': ['[INST]', '[/INST]', '<<SYS>>', '<</SYS>>'], 'pi': 3.14159},
|
||||||
"""
|
'messages': [{'role': 'user', 'content': 'Hello there!'}, {'role': 'assistant', 'content': 'Hello there yourself!'}],
|
||||||
PARAMETER stop [INST]
|
|
||||||
PARAMETER stop [/INST]
|
|
||||||
PARAMETER stop <<SYS>>
|
|
||||||
PARAMETER stop <</SYS>>''',
|
|
||||||
'stream': False,
|
'stream': False,
|
||||||
},
|
},
|
||||||
).respond_with_json({'status': 'success'})
|
).respond_with_json({'status': 'success'})
|
||||||
@ -661,22 +587,15 @@ PARAMETER stop <</SYS>>''',
|
|||||||
with tempfile.NamedTemporaryFile() as blob:
|
with tempfile.NamedTemporaryFile() as blob:
|
||||||
response = client.create(
|
response = client.create(
|
||||||
'dummy',
|
'dummy',
|
||||||
modelfile='\n'.join(
|
quantize='q4_k_m',
|
||||||
[
|
from_='mymodel',
|
||||||
f'FROM {blob.name}',
|
adapters={'someadapter.gguf': 'sha256:e3b0c44298fc1c149afbf4c8996fb92427ae41e4649b934ca495991b7852b855'},
|
||||||
'TEMPLATE """[INST] <<SYS>>{{.System}}<</SYS>>',
|
template='[INST] <<SYS>>{{.System}}<</SYS>>\n{{.Prompt}} [/INST]',
|
||||||
'{{.Prompt}} [/INST]"""',
|
license='this is my license',
|
||||||
'SYSTEM """',
|
system='\nUse\nmultiline\nstrings.\n',
|
||||||
'Use',
|
parameters={'stop': ['[INST]', '[/INST]', '<<SYS>>', '<</SYS>>'], 'pi': 3.14159},
|
||||||
'multiline',
|
messages=[{'role': 'user', 'content': 'Hello there!'}, {'role': 'assistant', 'content': 'Hello there yourself!'}],
|
||||||
'strings.',
|
stream=False,
|
||||||
'"""',
|
|
||||||
'PARAMETER stop [INST]',
|
|
||||||
'PARAMETER stop [/INST]',
|
|
||||||
'PARAMETER stop <<SYS>>',
|
|
||||||
'PARAMETER stop <</SYS>>',
|
|
||||||
]
|
|
||||||
),
|
|
||||||
)
|
)
|
||||||
assert response['status'] == 'success'
|
assert response['status'] == 'success'
|
||||||
|
|
||||||
@ -687,14 +606,14 @@ def test_client_create_from_library(httpserver: HTTPServer):
|
|||||||
method='POST',
|
method='POST',
|
||||||
json={
|
json={
|
||||||
'model': 'dummy',
|
'model': 'dummy',
|
||||||
'modelfile': 'FROM llama2',
|
'from': 'llama2',
|
||||||
'stream': False,
|
'stream': False,
|
||||||
},
|
},
|
||||||
).respond_with_json({'status': 'success'})
|
).respond_with_json({'status': 'success'})
|
||||||
|
|
||||||
client = Client(httpserver.url_for('/'))
|
client = Client(httpserver.url_for('/'))
|
||||||
|
|
||||||
response = client.create('dummy', modelfile='FROM llama2')
|
response = client.create('dummy', from_='llama2')
|
||||||
assert response['status'] == 'success'
|
assert response['status'] == 'success'
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@ -68,10 +68,10 @@ def test_create_request_serialization():
|
|||||||
system="test system",
|
system="test system",
|
||||||
parameters={"param1": "value1"}
|
parameters={"param1": "value1"}
|
||||||
)
|
)
|
||||||
|
|
||||||
serialized = request.model_dump()
|
serialized = request.model_dump()
|
||||||
assert serialized["from"] == "base-model"
|
assert serialized["from"] == "base-model"
|
||||||
assert "from_" not in serialized
|
assert "from_" not in serialized
|
||||||
assert serialized["quantize"] == "q4_0"
|
assert serialized["quantize"] == "q4_0"
|
||||||
assert serialized["files"] == {"file1": "content1"}
|
assert serialized["files"] == {"file1": "content1"}
|
||||||
assert serialized["adapters"] == {"adapter1": "content1"}
|
assert serialized["adapters"] == {"adapter1": "content1"}
|
||||||
@ -89,7 +89,7 @@ def test_create_request_serialization_exclude_none_true():
|
|||||||
quantize=None
|
quantize=None
|
||||||
)
|
)
|
||||||
serialized = request.model_dump(exclude_none=True)
|
serialized = request.model_dump(exclude_none=True)
|
||||||
assert serialized == {"model": "test-model"}
|
assert serialized == {"model": "test-model"}
|
||||||
assert "from" not in serialized
|
assert "from" not in serialized
|
||||||
assert "from_" not in serialized
|
assert "from_" not in serialized
|
||||||
assert "quantize" not in serialized
|
assert "quantize" not in serialized
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user