From ce56f279e84374f0c65db466fd2e54a4a664fb07 Mon Sep 17 00:00:00 2001 From: royjhan <65097070+royjhan@users.noreply.github.com> Date: Wed, 19 Jun 2024 16:10:44 -0700 Subject: [PATCH] Add type overloads to methods (#181) * Add type overloads for chat() method in _client.py * Overloading * Fix Overload Overlap * Fix async chat * Lint * Reverse --------- Co-authored-by: Simon Ottenhaus --- ollama/_client.py | 214 +++++++++++++++++++++++++++++++++++++++++++++- 1 file changed, 213 insertions(+), 1 deletion(-) diff --git a/ollama/_client.py b/ollama/_client.py index 7b55f29..1109aee 100644 --- a/ollama/_client.py +++ b/ollama/_client.py @@ -11,7 +11,7 @@ from copy import deepcopy from hashlib import sha256 from base64 import b64encode, b64decode -from typing import Any, AnyStr, Union, Optional, Sequence, Mapping, Literal +from typing import Any, AnyStr, Union, Optional, Sequence, Mapping, Literal, overload import sys @@ -97,6 +97,38 @@ class Client(BaseClient): ) -> Union[Mapping[str, Any], Iterator[Mapping[str, Any]]]: return self._stream(*args, **kwargs) if stream else self._request(*args, **kwargs).json() + @overload + def generate( + self, + model: str = '', + prompt: str = '', + system: str = '', + template: str = '', + context: Optional[Sequence[int]] = None, + stream: Literal[False] = False, + raw: bool = False, + format: Literal['', 'json'] = '', + images: Optional[Sequence[AnyStr]] = None, + options: Optional[Options] = None, + keep_alive: Optional[Union[float, str]] = None, + ) -> Mapping[str, Any]: ... + + @overload + def generate( + self, + model: str = '', + prompt: str = '', + system: str = '', + template: str = '', + context: Optional[Sequence[int]] = None, + stream: Literal[True] = True, + raw: bool = False, + format: Literal['', 'json'] = '', + images: Optional[Sequence[AnyStr]] = None, + options: Optional[Options] = None, + keep_alive: Optional[Union[float, str]] = None, + ) -> Iterator[Mapping[str, Any]]: ... + def generate( self, model: str = '', @@ -143,6 +175,28 @@ class Client(BaseClient): stream=stream, ) + @overload + def chat( + self, + model: str = '', + messages: Optional[Sequence[Message]] = None, + stream: Literal[False] = False, + format: Literal['', 'json'] = '', + options: Optional[Options] = None, + keep_alive: Optional[Union[float, str]] = None, + ) -> Mapping[str, Any]: ... + + @overload + def chat( + self, + model: str = '', + messages: Optional[Sequence[Message]] = None, + stream: Literal[True] = True, + format: Literal['', 'json'] = '', + options: Optional[Options] = None, + keep_alive: Optional[Union[float, str]] = None, + ) -> Iterator[Mapping[str, Any]]: ... + def chat( self, model: str = '', @@ -209,6 +263,22 @@ class Client(BaseClient): }, ).json() + @overload + def pull( + self, + model: str, + insecure: bool = False, + stream: Literal[False] = False, + ) -> Mapping[str, Any]: ... + + @overload + def pull( + self, + model: str, + insecure: bool = False, + stream: Literal[True] = True, + ) -> Iterator[Mapping[str, Any]]: ... + def pull( self, model: str, @@ -231,6 +301,22 @@ class Client(BaseClient): stream=stream, ) + @overload + def push( + self, + model: str, + insecure: bool = False, + stream: Literal[False] = False, + ) -> Mapping[str, Any]: ... + + @overload + def push( + self, + model: str, + insecure: bool = False, + stream: Literal[True] = True, + ) -> Iterator[Mapping[str, Any]]: ... + def push( self, model: str, @@ -253,6 +339,26 @@ class Client(BaseClient): stream=stream, ) + @overload + def create( + self, + model: str, + path: Optional[Union[str, PathLike]] = None, + modelfile: Optional[str] = None, + quantize: Optional[str] = None, + stream: Literal[False] = False, + ) -> Mapping[str, Any]: ... + + @overload + def create( + self, + model: str, + path: Optional[Union[str, PathLike]] = None, + modelfile: Optional[str] = None, + quantize: Optional[str] = None, + stream: Literal[True] = True, + ) -> Iterator[Mapping[str, Any]]: ... + def create( self, model: str, @@ -386,6 +492,38 @@ class AsyncClient(BaseClient): response = await self._request(*args, **kwargs) return response.json() + @overload + async def generate( + self, + model: str = '', + prompt: str = '', + system: str = '', + template: str = '', + context: Optional[Sequence[int]] = None, + stream: Literal[False] = False, + raw: bool = False, + format: Literal['', 'json'] = '', + images: Optional[Sequence[AnyStr]] = None, + options: Optional[Options] = None, + keep_alive: Optional[Union[float, str]] = None, + ) -> Mapping[str, Any]: ... + + @overload + async def generate( + self, + model: str = '', + prompt: str = '', + system: str = '', + template: str = '', + context: Optional[Sequence[int]] = None, + stream: Literal[True] = True, + raw: bool = False, + format: Literal['', 'json'] = '', + images: Optional[Sequence[AnyStr]] = None, + options: Optional[Options] = None, + keep_alive: Optional[Union[float, str]] = None, + ) -> AsyncIterator[Mapping[str, Any]]: ... + async def generate( self, model: str = '', @@ -431,6 +569,28 @@ class AsyncClient(BaseClient): stream=stream, ) + @overload + async def chat( + self, + model: str = '', + messages: Optional[Sequence[Message]] = None, + stream: Literal[False] = False, + format: Literal['', 'json'] = '', + options: Optional[Options] = None, + keep_alive: Optional[Union[float, str]] = None, + ) -> Mapping[str, Any]: ... + + @overload + async def chat( + self, + model: str = '', + messages: Optional[Sequence[Message]] = None, + stream: Literal[True] = True, + format: Literal['', 'json'] = '', + options: Optional[Options] = None, + keep_alive: Optional[Union[float, str]] = None, + ) -> AsyncIterator[Mapping[str, Any]]: ... + async def chat( self, model: str = '', @@ -498,6 +658,22 @@ class AsyncClient(BaseClient): return response.json() + @overload + async def pull( + self, + model: str, + insecure: bool = False, + stream: Literal[False] = False, + ) -> Mapping[str, Any]: ... + + @overload + async def pull( + self, + model: str, + insecure: bool = False, + stream: Literal[True] = True, + ) -> AsyncIterator[Mapping[str, Any]]: ... + async def pull( self, model: str, @@ -520,6 +696,22 @@ class AsyncClient(BaseClient): stream=stream, ) + @overload + async def push( + self, + model: str, + insecure: bool = False, + stream: Literal[False] = False, + ) -> Mapping[str, Any]: ... + + @overload + async def push( + self, + model: str, + insecure: bool = False, + stream: Literal[True] = True, + ) -> AsyncIterator[Mapping[str, Any]]: ... + async def push( self, model: str, @@ -542,6 +734,26 @@ class AsyncClient(BaseClient): stream=stream, ) + @overload + async def create( + self, + model: str, + path: Optional[Union[str, PathLike]] = None, + modelfile: Optional[str] = None, + quantize: Optional[str] = None, + stream: Literal[False] = False, + ) -> Mapping[str, Any]: ... + + @overload + async def create( + self, + model: str, + path: Optional[Union[str, PathLike]] = None, + modelfile: Optional[str] = None, + quantize: Optional[str] = None, + stream: Literal[True] = True, + ) -> AsyncIterator[Mapping[str, Any]]: ... + async def create( self, model: str,