add insert support to generate endpoint (#215)

* add suffix

* update fill-in-the-middle example

* keep example

* lint

* variables
This commit is contained in:
royjhan 2024-07-18 11:04:17 -07:00 committed by GitHub
parent b0ea6d9e44
commit 33c4b61ff9
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
3 changed files with 17 additions and 3 deletions

View File

@@ -1,16 +1,16 @@
from ollama import generate from ollama import generate
prefix = '''def remove_non_ascii(s: str) -> str: prompt = '''def remove_non_ascii(s: str) -> str:
""" ''' """ '''
suffix = """ suffix = """
return result return result
""" """
response = generate( response = generate(
model='codellama:7b-code', model='codellama:7b-code',
prompt=f'<PRE> {prefix} <SUF>{suffix} <MID>', prompt=prompt,
suffix=suffix,
options={ options={
'num_predict': 128, 'num_predict': 128,
'temperature': 0, 'temperature': 0,

View File

@@ -102,6 +102,7 @@ class Client(BaseClient):
self, self,
model: str = '', model: str = '',
prompt: str = '', prompt: str = '',
suffix: str = '',
system: str = '', system: str = '',
template: str = '', template: str = '',
context: Optional[Sequence[int]] = None, context: Optional[Sequence[int]] = None,
@@ -118,6 +119,7 @@ class Client(BaseClient):
self, self,
model: str = '', model: str = '',
prompt: str = '', prompt: str = '',
suffix: str = '',
system: str = '', system: str = '',
template: str = '', template: str = '',
context: Optional[Sequence[int]] = None, context: Optional[Sequence[int]] = None,
@@ -133,6 +135,7 @@ class Client(BaseClient):
self, self,
model: str = '', model: str = '',
prompt: str = '', prompt: str = '',
suffix: str = '',
system: str = '', system: str = '',
template: str = '', template: str = '',
context: Optional[Sequence[int]] = None, context: Optional[Sequence[int]] = None,
@@ -162,6 +165,7 @@ class Client(BaseClient):
json={ json={
'model': model, 'model': model,
'prompt': prompt, 'prompt': prompt,
'suffix': suffix,
'system': system, 'system': system,
'template': template, 'template': template,
'context': context or [], 'context': context or [],
@@ -518,6 +522,7 @@ class AsyncClient(BaseClient):
self, self,
model: str = '', model: str = '',
prompt: str = '', prompt: str = '',
suffix: str = '',
system: str = '', system: str = '',
template: str = '', template: str = '',
context: Optional[Sequence[int]] = None, context: Optional[Sequence[int]] = None,
@@ -534,6 +539,7 @@ class AsyncClient(BaseClient):
self, self,
model: str = '', model: str = '',
prompt: str = '', prompt: str = '',
suffix: str = '',
system: str = '', system: str = '',
template: str = '', template: str = '',
context: Optional[Sequence[int]] = None, context: Optional[Sequence[int]] = None,
@@ -549,6 +555,7 @@ class AsyncClient(BaseClient):
self, self,
model: str = '', model: str = '',
prompt: str = '', prompt: str = '',
suffix: str = '',
system: str = '', system: str = '',
template: str = '', template: str = '',
context: Optional[Sequence[int]] = None, context: Optional[Sequence[int]] = None,
@@ -577,6 +584,7 @@ class AsyncClient(BaseClient):
json={ json={
'model': model, 'model': model,
'prompt': prompt, 'prompt': prompt,
'suffix': suffix,
'system': system, 'system': system,
'template': template, 'template': template,
'context': context or [], 'context': context or [],

View File

@@ -137,6 +137,7 @@ def test_client_generate(httpserver: HTTPServer):
json={ json={
'model': 'dummy', 'model': 'dummy',
'prompt': 'Why is the sky blue?', 'prompt': 'Why is the sky blue?',
'suffix': '',
'system': '', 'system': '',
'template': '', 'template': '',
'context': [], 'context': [],
@@ -182,6 +183,7 @@ def test_client_generate_stream(httpserver: HTTPServer):
json={ json={
'model': 'dummy', 'model': 'dummy',
'prompt': 'Why is the sky blue?', 'prompt': 'Why is the sky blue?',
'suffix': '',
'system': '', 'system': '',
'template': '', 'template': '',
'context': [], 'context': [],
@@ -210,6 +212,7 @@ def test_client_generate_images(httpserver: HTTPServer):
json={ json={
'model': 'dummy', 'model': 'dummy',
'prompt': 'Why is the sky blue?', 'prompt': 'Why is the sky blue?',
'suffix': '',
'system': '', 'system': '',
'template': '', 'template': '',
'context': [], 'context': [],
@@ -619,6 +622,7 @@ async def test_async_client_generate(httpserver: HTTPServer):
json={ json={
'model': 'dummy', 'model': 'dummy',
'prompt': 'Why is the sky blue?', 'prompt': 'Why is the sky blue?',
'suffix': '',
'system': '', 'system': '',
'template': '', 'template': '',
'context': [], 'context': [],
@@ -659,6 +663,7 @@ async def test_async_client_generate_stream(httpserver: HTTPServer):
json={ json={
'model': 'dummy', 'model': 'dummy',
'prompt': 'Why is the sky blue?', 'prompt': 'Why is the sky blue?',
'suffix': '',
'system': '', 'system': '',
'template': '', 'template': '',
'context': [], 'context': [],
@@ -688,6 +693,7 @@ async def test_async_client_generate_images(httpserver: HTTPServer):
json={ json={
'model': 'dummy', 'model': 'dummy',
'prompt': 'Why is the sky blue?', 'prompt': 'Why is the sky blue?',
'suffix': '',
'system': '', 'system': '',
'template': '', 'template': '',
'context': [], 'context': [],