add unit tests

This commit is contained in:
Patrick Devine 2025-01-13 14:53:10 -08:00
parent 4f9fb88137
commit 4dec73e8be
3 changed files with 28 additions and 111 deletions

View File

@ -526,9 +526,6 @@ class Client(BaseClient):
Returns `ProgressResponse` if `stream` is `False`, otherwise returns a `ProgressResponse` generator. Returns `ProgressResponse` if `stream` is `False`, otherwise returns a `ProgressResponse` generator.
""" """
#if from_ == None and files == None:
# raise RequestError('neither ''from'' or ''files'' was specified')
return self._request( return self._request(
ProgressResponse, ProgressResponse,
'POST', 'POST',
@ -541,6 +538,7 @@ class Client(BaseClient):
files=files, files=files,
adapters=adapters, adapters=adapters,
license=license, license=license,
template=template,
system=system, system=system,
parameters=parameters, parameters=parameters,
messages=messages, messages=messages,

View File

@ -536,51 +536,6 @@ def test_client_push_stream(httpserver: HTTPServer):
assert part['status'] == next(it) assert part['status'] == next(it)
def test_client_create_path(httpserver: HTTPServer):
httpserver.expect_ordered_request(PrefixPattern('/api/blobs/'), method='POST').respond_with_response(Response(status=200))
httpserver.expect_ordered_request(
'/api/create',
method='POST',
json={
'model': 'dummy',
'modelfile': 'FROM @sha256:e3b0c44298fc1c149afbf4c8996fb92427ae41e4649b934ca495991b7852b855\n',
'stream': False,
},
).respond_with_json({'status': 'success'})
client = Client(httpserver.url_for('/'))
with tempfile.NamedTemporaryFile() as modelfile:
with tempfile.NamedTemporaryFile() as blob:
modelfile.write(f'FROM {blob.name}'.encode('utf-8'))
modelfile.flush()
response = client.create('dummy', path=modelfile.name)
assert response['status'] == 'success'
def test_client_create_path_relative(httpserver: HTTPServer):
httpserver.expect_ordered_request(PrefixPattern('/api/blobs/'), method='POST').respond_with_response(Response(status=200))
httpserver.expect_ordered_request(
'/api/create',
method='POST',
json={
'model': 'dummy',
'modelfile': 'FROM @sha256:e3b0c44298fc1c149afbf4c8996fb92427ae41e4649b934ca495991b7852b855\n',
'stream': False,
},
).respond_with_json({'status': 'success'})
client = Client(httpserver.url_for('/'))
with tempfile.NamedTemporaryFile() as modelfile:
with tempfile.NamedTemporaryFile(dir=Path(modelfile.name).parent) as blob:
modelfile.write(f'FROM {Path(blob.name).name}'.encode('utf-8'))
modelfile.flush()
response = client.create('dummy', path=modelfile.name)
assert response['status'] == 'success'
@pytest.fixture @pytest.fixture
def userhomedir(): def userhomedir():
@ -591,37 +546,13 @@ def userhomedir():
os.environ['HOME'] = home os.environ['HOME'] = home
def test_client_create_path_user_home(httpserver: HTTPServer, userhomedir): def test_client_create_with_blob(httpserver: HTTPServer):
httpserver.expect_ordered_request(PrefixPattern('/api/blobs/'), method='POST').respond_with_response(Response(status=200))
httpserver.expect_ordered_request( httpserver.expect_ordered_request(
'/api/create', '/api/create',
method='POST', method='POST',
json={ json={
'model': 'dummy', 'model': 'dummy',
'modelfile': 'FROM @sha256:e3b0c44298fc1c149afbf4c8996fb92427ae41e4649b934ca495991b7852b855\n', 'files': {'test.gguf': 'sha256:e3b0c44298fc1c149afbf4c8996fb92427ae41e4649b934ca495991b7852b855'},
'stream': False,
},
).respond_with_json({'status': 'success'})
client = Client(httpserver.url_for('/'))
with tempfile.NamedTemporaryFile() as modelfile:
with tempfile.NamedTemporaryFile(dir=userhomedir) as blob:
modelfile.write(f'FROM ~/{Path(blob.name).name}'.encode('utf-8'))
modelfile.flush()
response = client.create('dummy', path=modelfile.name)
assert response['status'] == 'success'
def test_client_create_modelfile(httpserver: HTTPServer):
httpserver.expect_ordered_request(PrefixPattern('/api/blobs/'), method='POST').respond_with_response(Response(status=200))
httpserver.expect_ordered_request(
'/api/create',
method='POST',
json={
'model': 'dummy',
'modelfile': 'FROM @sha256:e3b0c44298fc1c149afbf4c8996fb92427ae41e4649b934ca495991b7852b855\n',
'stream': False, 'stream': False,
}, },
).respond_with_json({'status': 'success'}) ).respond_with_json({'status': 'success'})
@ -629,29 +560,24 @@ def test_client_create_modelfile(httpserver: HTTPServer):
client = Client(httpserver.url_for('/')) client = Client(httpserver.url_for('/'))
with tempfile.NamedTemporaryFile() as blob: with tempfile.NamedTemporaryFile() as blob:
response = client.create('dummy', modelfile=f'FROM {blob.name}') response = client.create('dummy', files={'test.gguf': 'sha256:e3b0c44298fc1c149afbf4c8996fb92427ae41e4649b934ca495991b7852b855'})
assert response['status'] == 'success' assert response['status'] == 'success'
def test_client_create_modelfile_roundtrip(httpserver: HTTPServer): def test_client_create_with_parameters_roundtrip(httpserver: HTTPServer):
httpserver.expect_ordered_request(PrefixPattern('/api/blobs/'), method='POST').respond_with_response(Response(status=200))
httpserver.expect_ordered_request( httpserver.expect_ordered_request(
'/api/create', '/api/create',
method='POST', method='POST',
json={ json={
'model': 'dummy', 'model': 'dummy',
'modelfile': '''FROM @sha256:e3b0c44298fc1c149afbf4c8996fb92427ae41e4649b934ca495991b7852b855 'quantize': 'q4_k_m',
TEMPLATE """[INST] <<SYS>>{{.System}}<</SYS>> 'from': 'mymodel',
{{.Prompt}} [/INST]""" 'adapters': {'someadapter.gguf': 'sha256:e3b0c44298fc1c149afbf4c8996fb92427ae41e4649b934ca495991b7852b855'},
SYSTEM """ 'template': '[INST] <<SYS>>{{.System}}<</SYS>>\n{{.Prompt}} [/INST]',
Use 'license': 'this is my license',
multiline 'system': '\nUse\nmultiline\nstrings.\n',
strings. 'parameters': {'stop': ['[INST]', '[/INST]', '<<SYS>>', '<</SYS>>'], 'pi': 3.14159},
""" 'messages': [{'role': 'user', 'content': 'Hello there!'}, {'role': 'assistant', 'content': 'Hello there yourself!'}],
PARAMETER stop [INST]
PARAMETER stop [/INST]
PARAMETER stop <<SYS>>
PARAMETER stop <</SYS>>''',
'stream': False, 'stream': False,
}, },
).respond_with_json({'status': 'success'}) ).respond_with_json({'status': 'success'})
@ -661,22 +587,15 @@ PARAMETER stop <</SYS>>''',
with tempfile.NamedTemporaryFile() as blob: with tempfile.NamedTemporaryFile() as blob:
response = client.create( response = client.create(
'dummy', 'dummy',
modelfile='\n'.join( quantize='q4_k_m',
[ from_='mymodel',
f'FROM {blob.name}', adapters={'someadapter.gguf': 'sha256:e3b0c44298fc1c149afbf4c8996fb92427ae41e4649b934ca495991b7852b855'},
'TEMPLATE """[INST] <<SYS>>{{.System}}<</SYS>>', template='[INST] <<SYS>>{{.System}}<</SYS>>\n{{.Prompt}} [/INST]',
'{{.Prompt}} [/INST]"""', license='this is my license',
'SYSTEM """', system='\nUse\nmultiline\nstrings.\n',
'Use', parameters={'stop': ['[INST]', '[/INST]', '<<SYS>>', '<</SYS>>'], 'pi': 3.14159},
'multiline', messages=[{'role': 'user', 'content': 'Hello there!'}, {'role': 'assistant', 'content': 'Hello there yourself!'}],
'strings.', stream=False,
'"""',
'PARAMETER stop [INST]',
'PARAMETER stop [/INST]',
'PARAMETER stop <<SYS>>',
'PARAMETER stop <</SYS>>',
]
),
) )
assert response['status'] == 'success' assert response['status'] == 'success'
@ -687,14 +606,14 @@ def test_client_create_from_library(httpserver: HTTPServer):
method='POST', method='POST',
json={ json={
'model': 'dummy', 'model': 'dummy',
'modelfile': 'FROM llama2', 'from': 'llama2',
'stream': False, 'stream': False,
}, },
).respond_with_json({'status': 'success'}) ).respond_with_json({'status': 'success'})
client = Client(httpserver.url_for('/')) client = Client(httpserver.url_for('/'))
response = client.create('dummy', modelfile='FROM llama2') response = client.create('dummy', from_='llama2')
assert response['status'] == 'success' assert response['status'] == 'success'

View File

@ -68,10 +68,10 @@ def test_create_request_serialization():
system="test system", system="test system",
parameters={"param1": "value1"} parameters={"param1": "value1"}
) )
serialized = request.model_dump() serialized = request.model_dump()
assert serialized["from"] == "base-model" assert serialized["from"] == "base-model"
assert "from_" not in serialized assert "from_" not in serialized
assert serialized["quantize"] == "q4_0" assert serialized["quantize"] == "q4_0"
assert serialized["files"] == {"file1": "content1"} assert serialized["files"] == {"file1": "content1"}
assert serialized["adapters"] == {"adapter1": "content1"} assert serialized["adapters"] == {"adapter1": "content1"}
@ -89,7 +89,7 @@ def test_create_request_serialization_exclude_none_true():
quantize=None quantize=None
) )
serialized = request.model_dump(exclude_none=True) serialized = request.model_dump(exclude_none=True)
assert serialized == {"model": "test-model"} assert serialized == {"model": "test-model"}
assert "from" not in serialized assert "from" not in serialized
assert "from_" not in serialized assert "from_" not in serialized
assert "quantize" not in serialized assert "quantize" not in serialized