diff --git a/examples/fill-in-middle/main.py b/examples/fill-in-middle/main.py
index 67d7a74..0bd2e01 100644
--- a/examples/fill-in-middle/main.py
+++ b/examples/fill-in-middle/main.py
@@ -1,16 +1,16 @@
 from ollama import generate
 
-prefix = '''def remove_non_ascii(s: str) -> str:
+prompt = '''def remove_non_ascii(s: str) -> str:
   """
 '''
 
 suffix = """
   return result
 """
 
-
 response = generate(
   model='codellama:7b-code',
-  prompt=f'<PRE> {prefix} <SUF>{suffix} <MID>',
+  prompt=prompt,
+  suffix=suffix,
   options={
     'num_predict': 128,
     'temperature': 0,
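
The hunk above ends inside the options dict, so here is the updated example as a single runnable sketch. Only the options visible in the hunk are shown; anything past 'temperature' is truncated in the diff and omitted here rather than guessed. The final print assumes the dict-style response the library returns.

from ollama import generate

prompt = '''def remove_non_ascii(s: str) -> str:
  """
'''

suffix = """
  return result
"""

# prompt and suffix are now sent as separate fields and the server applies
# the model's fill-in-middle template, so the example no longer assembles
# the <PRE>/<SUF>/<MID> tokens by hand.
response = generate(
  model='codellama:7b-code',
  prompt=prompt,
  suffix=suffix,
  options={
    'num_predict': 128,
    'temperature': 0,
  },
)
print(response['response'])
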
diff --git a/ollama/_client.py b/ollama/_client.py
index e991092..de26805 100644
--- a/ollama/_client.py
+++ b/ollama/_client.py
@@ -102,6 +102,7 @@ class Client(BaseClient):
     self,
     model: str = '',
     prompt: str = '',
+    suffix: str = '',
     system: str = '',
     template: str = '',
     context: Optional[Sequence[int]] = None,
@@ -118,6 +119,7 @@ class Client(BaseClient):
     self,
     model: str = '',
     prompt: str = '',
+    suffix: str = '',
     system: str = '',
     template: str = '',
     context: Optional[Sequence[int]] = None,
@@ -133,6 +135,7 @@ class Client(BaseClient):
     self,
     model: str = '',
     prompt: str = '',
+    suffix: str = '',
     system: str = '',
     template: str = '',
     context: Optional[Sequence[int]] = None,
@@ -162,6 +165,7 @@ class Client(BaseClient):
       json={
         'model': model,
         'prompt': prompt,
+        'suffix': suffix,
         'system': system,
         'template': template,
         'context': context or [],
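
With the three overloads and the request payload updated, the synchronous call surface gains only the new keyword, and suffix defaults to '' so existing callers are unaffected. A minimal usage sketch; the model name is reused from the example above, and the prompt/suffix strings here are invented for illustration:

from ollama import Client

client = Client()  # defaults to http://localhost:11434

# Passing suffix asks the server to complete the text that fits between
# the prompt and the suffix.
response = client.generate(
  model='codellama:7b-code',
  prompt='def add(a: int, b: int) -> int:\n',
  suffix='\n  return result\n',
)
print(response['response'])
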
@@ -518,6 +522,7 @@ class AsyncClient(BaseClient):
     self,
     model: str = '',
     prompt: str = '',
+    suffix: str = '',
     system: str = '',
     template: str = '',
     context: Optional[Sequence[int]] = None,
@@ -534,6 +539,7 @@ class AsyncClient(BaseClient):
     self,
     model: str = '',
     prompt: str = '',
+    suffix: str = '',
     system: str = '',
     template: str = '',
     context: Optional[Sequence[int]] = None,
@@ -549,6 +555,7 @@ class AsyncClient(BaseClient):
     self,
     model: str = '',
     prompt: str = '',
+    suffix: str = '',
     system: str = '',
     template: str = '',
     context: Optional[Sequence[int]] = None,
@@ -577,6 +584,7 @@ class AsyncClient(BaseClient):
       json={
         'model': model,
         'prompt': prompt,
+        'suffix': suffix,
         'system': system,
         'template': template,
         'context': context or [],
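
The AsyncClient changes mirror the synchronous ones, so the same keyword works under asyncio. A sketch, again with invented prompt/suffix strings:

import asyncio

from ollama import AsyncClient


async def main():
  response = await AsyncClient().generate(
    model='codellama:7b-code',
    prompt='def fib(n: int) -> int:\n',
    suffix='\n  return a\n',
  )
  print(response['response'])


asyncio.run(main())
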
diff --git a/tests/test_client.py b/tests/test_client.py
index 727c239..0b062f5 100644
--- a/tests/test_client.py
+++ b/tests/test_client.py
@@ -137,6 +137,7 @@ def test_client_generate(httpserver: HTTPServer):
     json={
       'model': 'dummy',
       'prompt': 'Why is the sky blue?',
+      'suffix': '',
       'system': '',
       'template': '',
       'context': [],
@@ -182,6 +183,7 @@ def test_client_generate_stream(httpserver: HTTPServer):
     json={
       'model': 'dummy',
       'prompt': 'Why is the sky blue?',
+      'suffix': '',
       'system': '',
       'template': '',
       'context': [],
@@ -210,6 +212,7 @@ def test_client_generate_images(httpserver: HTTPServer):
     json={
       'model': 'dummy',
       'prompt': 'Why is the sky blue?',
+      'suffix': '',
       'system': '',
       'template': '',
       'context': [],
@@ -619,6 +622,7 @@ async def test_async_client_generate(httpserver: HTTPServer):
     json={
       'model': 'dummy',
       'prompt': 'Why is the sky blue?',
+      'suffix': '',
       'system': '',
       'template': '',
       'context': [],
@@ -659,6 +663,7 @@ async def test_async_client_generate_stream(httpserver: HTTPServer):
     json={
       'model': 'dummy',
       'prompt': 'Why is the sky blue?',
+      'suffix': '',
       'system': '',
       'template': '',
       'context': [],
@@ -688,6 +693,7 @@ async def test_async_client_generate_images(httpserver: HTTPServer):
     json={
       'model': 'dummy',
       'prompt': 'Why is the sky blue?',
+      'suffix': '',
       'system': '',
       'template': '',
       'context': [],
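
The updated tests only assert that an empty suffix is serialized by default. A sketch of a companion test that exercises a non-empty suffix, following the same pytest-httpserver pattern as the tests above; since the diff truncates the expected JSON bodies after 'context', this sketch avoids an exact-body match and instead reads the suffix back from the server's request log. The test name, prompt, suffix, and response text are invented for illustration:

from pytest_httpserver import HTTPServer

from ollama import Client


def test_client_generate_with_suffix(httpserver: HTTPServer):
  # Accept any POST body; the suffix is asserted from the request log below
  # rather than via an exact-JSON expectation.
  httpserver.expect_ordered_request('/api/generate', method='POST').respond_with_json(
    {
      'model': 'dummy',
      'response': '  return "".join(c for c in s if ord(c) < 128)',
    }
  )

  client = Client(httpserver.url_for('/'))
  client.generate('dummy', 'def remove_non_ascii(s):\n', suffix='\n  return result\n')

  # pytest-httpserver records (request, response) pairs; check the suffix
  # was forwarded verbatim in the JSON body.
  request, _ = httpserver.log[0]
  assert request.get_json()['suffix'] == '\n  return result\n'
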