From 33c4b61ff99bb79133d6965bf40120eda2428efa Mon Sep 17 00:00:00 2001 From: royjhan <65097070+royjhan@users.noreply.github.com> Date: Thu, 18 Jul 2024 11:04:17 -0700 Subject: [PATCH] add insert support to generate endpoint (#215) * add suffix * update fill-in-the-middle example * keep example * lint * variables --- examples/fill-in-middle/main.py | 6 +++--- ollama/_client.py | 8 ++++++++ tests/test_client.py | 6 ++++++ 3 files changed, 17 insertions(+), 3 deletions(-) diff --git a/examples/fill-in-middle/main.py b/examples/fill-in-middle/main.py index 67d7a74..0bd2e01 100644 --- a/examples/fill-in-middle/main.py +++ b/examples/fill-in-middle/main.py @@ -1,16 +1,16 @@ from ollama import generate -prefix = '''def remove_non_ascii(s: str) -> str: +prompt = '''def remove_non_ascii(s: str) -> str: """ ''' suffix = """ return result """ - response = generate( model='codellama:7b-code', - prompt=f'
{prefix} {suffix} ',
+ prompt=prompt,
+ suffix=suffix,
options={
'num_predict': 128,
'temperature': 0,
diff --git a/ollama/_client.py b/ollama/_client.py
index e991092..de26805 100644
--- a/ollama/_client.py
+++ b/ollama/_client.py
@@ -102,6 +102,7 @@ class Client(BaseClient):
self,
model: str = '',
prompt: str = '',
+ suffix: str = '',
system: str = '',
template: str = '',
context: Optional[Sequence[int]] = None,
@@ -118,6 +119,7 @@ class Client(BaseClient):
self,
model: str = '',
prompt: str = '',
+ suffix: str = '',
system: str = '',
template: str = '',
context: Optional[Sequence[int]] = None,
@@ -133,6 +135,7 @@ class Client(BaseClient):
self,
model: str = '',
prompt: str = '',
+ suffix: str = '',
system: str = '',
template: str = '',
context: Optional[Sequence[int]] = None,
@@ -162,6 +165,7 @@ class Client(BaseClient):
json={
'model': model,
'prompt': prompt,
+ 'suffix': suffix,
'system': system,
'template': template,
'context': context or [],
@@ -518,6 +522,7 @@ class AsyncClient(BaseClient):
self,
model: str = '',
prompt: str = '',
+ suffix: str = '',
system: str = '',
template: str = '',
context: Optional[Sequence[int]] = None,
@@ -534,6 +539,7 @@ class AsyncClient(BaseClient):
self,
model: str = '',
prompt: str = '',
+ suffix: str = '',
system: str = '',
template: str = '',
context: Optional[Sequence[int]] = None,
@@ -549,6 +555,7 @@ class AsyncClient(BaseClient):
self,
model: str = '',
prompt: str = '',
+ suffix: str = '',
system: str = '',
template: str = '',
context: Optional[Sequence[int]] = None,
@@ -577,6 +584,7 @@ class AsyncClient(BaseClient):
json={
'model': model,
'prompt': prompt,
+ 'suffix': suffix,
'system': system,
'template': template,
'context': context or [],
diff --git a/tests/test_client.py b/tests/test_client.py
index 727c239..0b062f5 100644
--- a/tests/test_client.py
+++ b/tests/test_client.py
@@ -137,6 +137,7 @@ def test_client_generate(httpserver: HTTPServer):
json={
'model': 'dummy',
'prompt': 'Why is the sky blue?',
+ 'suffix': '',
'system': '',
'template': '',
'context': [],
@@ -182,6 +183,7 @@ def test_client_generate_stream(httpserver: HTTPServer):
json={
'model': 'dummy',
'prompt': 'Why is the sky blue?',
+ 'suffix': '',
'system': '',
'template': '',
'context': [],
@@ -210,6 +212,7 @@ def test_client_generate_images(httpserver: HTTPServer):
json={
'model': 'dummy',
'prompt': 'Why is the sky blue?',
+ 'suffix': '',
'system': '',
'template': '',
'context': [],
@@ -619,6 +622,7 @@ async def test_async_client_generate(httpserver: HTTPServer):
json={
'model': 'dummy',
'prompt': 'Why is the sky blue?',
+ 'suffix': '',
'system': '',
'template': '',
'context': [],
@@ -659,6 +663,7 @@ async def test_async_client_generate_stream(httpserver: HTTPServer):
json={
'model': 'dummy',
'prompt': 'Why is the sky blue?',
+ 'suffix': '',
'system': '',
'template': '',
'context': [],
@@ -688,6 +693,7 @@ async def test_async_client_generate_images(httpserver: HTTPServer):
json={
'model': 'dummy',
'prompt': 'Why is the sky blue?',
+ 'suffix': '',
'system': '',
'template': '',
'context': [],