Add type overloads to methods (#181)

* Add type overloads for chat() method in _client.py

* Overloading

* Fix Overload Overlap

* Fix async chat

* Lint

* Reverse

---------

Co-authored-by: Simon Ottenhaus <simon.ottenhaus@kenbun.de>
This commit is contained in:
royjhan 2024-06-19 16:10:44 -07:00 committed by GitHub
parent 982d65fea0
commit ce56f279e8
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

View File

@ -11,7 +11,7 @@ from copy import deepcopy
from hashlib import sha256
from base64 import b64encode, b64decode
from typing import Any, AnyStr, Union, Optional, Sequence, Mapping, Literal
from typing import Any, AnyStr, Union, Optional, Sequence, Mapping, Literal, overload
import sys
@ -97,6 +97,38 @@ class Client(BaseClient):
) -> Union[Mapping[str, Any], Iterator[Mapping[str, Any]]]:
return self._stream(*args, **kwargs) if stream else self._request(*args, **kwargs).json()
@overload
def generate(
self,
model: str = '',
prompt: str = '',
system: str = '',
template: str = '',
context: Optional[Sequence[int]] = None,
stream: Literal[False] = False,
raw: bool = False,
format: Literal['', 'json'] = '',
images: Optional[Sequence[AnyStr]] = None,
options: Optional[Options] = None,
keep_alive: Optional[Union[float, str]] = None,
) -> Mapping[str, Any]: ...
@overload
def generate(
self,
model: str = '',
prompt: str = '',
system: str = '',
template: str = '',
context: Optional[Sequence[int]] = None,
stream: Literal[True] = True,
raw: bool = False,
format: Literal['', 'json'] = '',
images: Optional[Sequence[AnyStr]] = None,
options: Optional[Options] = None,
keep_alive: Optional[Union[float, str]] = None,
) -> Iterator[Mapping[str, Any]]: ...
def generate(
self,
model: str = '',
@ -143,6 +175,28 @@ class Client(BaseClient):
stream=stream,
)
@overload
def chat(
self,
model: str = '',
messages: Optional[Sequence[Message]] = None,
stream: Literal[False] = False,
format: Literal['', 'json'] = '',
options: Optional[Options] = None,
keep_alive: Optional[Union[float, str]] = None,
) -> Mapping[str, Any]: ...
@overload
def chat(
self,
model: str = '',
messages: Optional[Sequence[Message]] = None,
stream: Literal[True] = True,
format: Literal['', 'json'] = '',
options: Optional[Options] = None,
keep_alive: Optional[Union[float, str]] = None,
) -> Iterator[Mapping[str, Any]]: ...
def chat(
self,
model: str = '',
@ -209,6 +263,22 @@ class Client(BaseClient):
},
).json()
@overload
def pull(
self,
model: str,
insecure: bool = False,
stream: Literal[False] = False,
) -> Mapping[str, Any]: ...
@overload
def pull(
self,
model: str,
insecure: bool = False,
stream: Literal[True] = True,
) -> Iterator[Mapping[str, Any]]: ...
def pull(
self,
model: str,
@ -231,6 +301,22 @@ class Client(BaseClient):
stream=stream,
)
@overload
def push(
self,
model: str,
insecure: bool = False,
stream: Literal[False] = False,
) -> Mapping[str, Any]: ...
@overload
def push(
self,
model: str,
insecure: bool = False,
stream: Literal[True] = True,
) -> Iterator[Mapping[str, Any]]: ...
def push(
self,
model: str,
@ -253,6 +339,26 @@ class Client(BaseClient):
stream=stream,
)
@overload
def create(
self,
model: str,
path: Optional[Union[str, PathLike]] = None,
modelfile: Optional[str] = None,
quantize: Optional[str] = None,
stream: Literal[False] = False,
) -> Mapping[str, Any]: ...
@overload
def create(
self,
model: str,
path: Optional[Union[str, PathLike]] = None,
modelfile: Optional[str] = None,
quantize: Optional[str] = None,
stream: Literal[True] = True,
) -> Iterator[Mapping[str, Any]]: ...
def create(
self,
model: str,
@ -386,6 +492,38 @@ class AsyncClient(BaseClient):
response = await self._request(*args, **kwargs)
return response.json()
@overload
async def generate(
self,
model: str = '',
prompt: str = '',
system: str = '',
template: str = '',
context: Optional[Sequence[int]] = None,
stream: Literal[False] = False,
raw: bool = False,
format: Literal['', 'json'] = '',
images: Optional[Sequence[AnyStr]] = None,
options: Optional[Options] = None,
keep_alive: Optional[Union[float, str]] = None,
) -> Mapping[str, Any]: ...
@overload
async def generate(
self,
model: str = '',
prompt: str = '',
system: str = '',
template: str = '',
context: Optional[Sequence[int]] = None,
stream: Literal[True] = True,
raw: bool = False,
format: Literal['', 'json'] = '',
images: Optional[Sequence[AnyStr]] = None,
options: Optional[Options] = None,
keep_alive: Optional[Union[float, str]] = None,
) -> AsyncIterator[Mapping[str, Any]]: ...
async def generate(
self,
model: str = '',
@ -431,6 +569,28 @@ class AsyncClient(BaseClient):
stream=stream,
)
@overload
async def chat(
self,
model: str = '',
messages: Optional[Sequence[Message]] = None,
stream: Literal[False] = False,
format: Literal['', 'json'] = '',
options: Optional[Options] = None,
keep_alive: Optional[Union[float, str]] = None,
) -> Mapping[str, Any]: ...
@overload
async def chat(
self,
model: str = '',
messages: Optional[Sequence[Message]] = None,
stream: Literal[True] = True,
format: Literal['', 'json'] = '',
options: Optional[Options] = None,
keep_alive: Optional[Union[float, str]] = None,
) -> AsyncIterator[Mapping[str, Any]]: ...
async def chat(
self,
model: str = '',
@ -498,6 +658,22 @@ class AsyncClient(BaseClient):
return response.json()
@overload
async def pull(
self,
model: str,
insecure: bool = False,
stream: Literal[False] = False,
) -> Mapping[str, Any]: ...
@overload
async def pull(
self,
model: str,
insecure: bool = False,
stream: Literal[True] = True,
) -> AsyncIterator[Mapping[str, Any]]: ...
async def pull(
self,
model: str,
@ -520,6 +696,22 @@ class AsyncClient(BaseClient):
stream=stream,
)
@overload
async def push(
self,
model: str,
insecure: bool = False,
stream: Literal[False] = False,
) -> Mapping[str, Any]: ...
@overload
async def push(
self,
model: str,
insecure: bool = False,
stream: Literal[True] = True,
) -> AsyncIterator[Mapping[str, Any]]: ...
async def push(
self,
model: str,
@ -542,6 +734,26 @@ class AsyncClient(BaseClient):
stream=stream,
)
@overload
async def create(
self,
model: str,
path: Optional[Union[str, PathLike]] = None,
modelfile: Optional[str] = None,
quantize: Optional[str] = None,
stream: Literal[False] = False,
) -> Mapping[str, Any]: ...
@overload
async def create(
self,
model: str,
path: Optional[Union[str, PathLike]] = None,
modelfile: Optional[str] = None,
quantize: Optional[str] = None,
stream: Literal[True] = True,
) -> AsyncIterator[Mapping[str, Any]]: ...
async def create(
self,
model: str,