fix parse modelfile

- do not add newlines while parsing
- do not add leading whitespace
This commit is contained in:
Michael Yang 2024-01-29 12:50:45 -08:00
parent f618a2f448
commit 8e5d431d0d
2 changed files with 117 additions and 14 deletions

View File

@ -259,13 +259,16 @@ class Client(BaseClient):
out = io.StringIO()
for line in io.StringIO(modelfile):
command, _, args = line.partition(' ')
if command.upper() in ['FROM', 'ADAPTER']:
path = Path(args.strip()).expanduser()
path = path if path.is_absolute() else base / path
if path.exists():
args = f'@{self._create_blob(path)}'
if command.upper() not in ['FROM', 'ADAPTER']:
print(line, end='', file=out)
continue
path = Path(args.strip()).expanduser()
path = path if path.is_absolute() else base / path
if path.exists():
args = f'@{self._create_blob(path)}\n'
print(command, args, end='', file=out)
print(command, args, file=out)
return out.getvalue()
def _create_blob(self, path: Union[str, Path]) -> str:
@ -527,13 +530,16 @@ class AsyncClient(BaseClient):
out = io.StringIO()
for line in io.StringIO(modelfile):
command, _, args = line.partition(' ')
if command.upper() in ['FROM', 'ADAPTER']:
path = Path(args).expanduser()
path = path if path.is_absolute() else base / path
if path.exists():
args = f'@{await self._create_blob(path)}'
if command.upper() not in ['FROM', 'ADAPTER']:
print(line, end='', file=out)
continue
path = Path(args.strip()).expanduser()
path = path if path.is_absolute() else base / path
if path.exists():
args = f'@{await self._create_blob(path)}\n'
print(command, args, end='', file=out)
print(command, args, file=out)
return out.getvalue()
async def _create_blob(self, path: Union[str, Path]) -> str:

View File

@ -416,13 +416,61 @@ def test_client_create_modelfile(httpserver: HTTPServer):
assert isinstance(response, dict)
def test_client_create_modelfile_roundtrip(httpserver: HTTPServer):
httpserver.expect_ordered_request(PrefixPattern('/api/blobs/'), method='HEAD').respond_with_response(Response(status=200))
httpserver.expect_ordered_request(
'/api/create',
method='POST',
json={
'name': 'dummy',
'modelfile': '''FROM @sha256:e3b0c44298fc1c149afbf4c8996fb92427ae41e4649b934ca495991b7852b855
TEMPLATE """[INST] <<SYS>>{{.System}}<</SYS>>
{{.Prompt}} [/INST]"""
SYSTEM """
Use
multiline
strings.
"""
PARAMETER stop [INST]
PARAMETER stop [/INST]
PARAMETER stop <<SYS>>
PARAMETER stop <</SYS>>''',
'stream': False,
},
).respond_with_json({})
client = Client(httpserver.url_for('/'))
with tempfile.NamedTemporaryFile() as blob:
response = client.create(
'dummy',
modelfile='\n'.join(
[
f'FROM {blob.name}',
'TEMPLATE """[INST] <<SYS>>{{.System}}<</SYS>>',
'{{.Prompt}} [/INST]"""',
'SYSTEM """',
'Use',
'multiline',
'strings.',
'"""',
'PARAMETER stop [INST]',
'PARAMETER stop [/INST]',
'PARAMETER stop <<SYS>>',
'PARAMETER stop <</SYS>>',
]
),
)
assert isinstance(response, dict)
def test_client_create_from_library(httpserver: HTTPServer):
httpserver.expect_ordered_request(
'/api/create',
method='POST',
json={
'name': 'dummy',
'modelfile': 'FROM llama2\n',
'modelfile': 'FROM llama2',
'stream': False,
},
).respond_with_json({})
@ -820,6 +868,55 @@ async def test_async_client_create_modelfile(httpserver: HTTPServer):
assert isinstance(response, dict)
@pytest.mark.asyncio
async def test_async_client_create_modelfile_roundtrip(httpserver: HTTPServer):
httpserver.expect_ordered_request(PrefixPattern('/api/blobs/'), method='HEAD').respond_with_response(Response(status=200))
httpserver.expect_ordered_request(
'/api/create',
method='POST',
json={
'name': 'dummy',
'modelfile': '''FROM @sha256:e3b0c44298fc1c149afbf4c8996fb92427ae41e4649b934ca495991b7852b855
TEMPLATE """[INST] <<SYS>>{{.System}}<</SYS>>
{{.Prompt}} [/INST]"""
SYSTEM """
Use
multiline
strings.
"""
PARAMETER stop [INST]
PARAMETER stop [/INST]
PARAMETER stop <<SYS>>
PARAMETER stop <</SYS>>''',
'stream': False,
},
).respond_with_json({})
client = AsyncClient(httpserver.url_for('/'))
with tempfile.NamedTemporaryFile() as blob:
response = await client.create(
'dummy',
modelfile='\n'.join(
[
f'FROM {blob.name}',
'TEMPLATE """[INST] <<SYS>>{{.System}}<</SYS>>',
'{{.Prompt}} [/INST]"""',
'SYSTEM """',
'Use',
'multiline',
'strings.',
'"""',
'PARAMETER stop [INST]',
'PARAMETER stop [/INST]',
'PARAMETER stop <<SYS>>',
'PARAMETER stop <</SYS>>',
]
),
)
assert isinstance(response, dict)
@pytest.mark.asyncio
async def test_async_client_create_from_library(httpserver: HTTPServer):
httpserver.expect_ordered_request(
@ -827,7 +924,7 @@ async def test_async_client_create_from_library(httpserver: HTTPServer):
method='POST',
json={
'name': 'dummy',
'modelfile': 'FROM llama2\n',
'modelfile': 'FROM llama2',
'stream': False,
},
).respond_with_json({})