Compare commits

..

9 Commits

Author SHA1 Message Date
Parth Sareen d6528cf731 Fix image serialization for long image string (#348) 2024-11-28 14:12:55 -08:00
Julia Scheaffer b50a65b27d Add Callable type annotation for Tools (#344) 2024-11-27 09:53:26 -08:00
Jeffrey Morgan 758a1d2933 make subscription methods more consistent 2024-11-26 10:52:45 -08:00
Jeffrey Morgan d4c38978d1 Update ollama/_types.py
Co-authored-by: Parth Sareen <parth.sareen@ollama.com>
2024-11-26 10:41:53 -08:00
Jeffrey Morgan d8d98e17b2 Update ollama/_types.py
Co-authored-by: Parth Sareen <parth.sareen@ollama.com>
2024-11-26 10:41:50 -08:00
Jeffrey Morgan ec2c8fdd8d Update ollama/_types.py
Co-authored-by: Parth Sareen <parth.sareen@ollama.com>
2024-11-26 10:41:45 -08:00
Jeffrey Morgan ea0e0dc692 Update ollama/_types.py
Co-authored-by: Parth Sareen <parth.sareen@ollama.com>
2024-11-26 10:35:57 -08:00
Shahil Yadav 6c44bb2729 Fix chat-with-history.py example (#337)
Fix chat-with-history.py example

---------

Co-authored-by: Parth Sareen <parth.sareen@ollama.com>
2024-11-24 10:49:46 -08:00
jmorganca 2095fc9107 make subscription methods more consistent with maps 2024-11-23 19:02:28 -08:00
4 changed files with 57 additions and 10 deletions
+2 -2
View File
@@ -31,8 +31,8 @@ while True:
)
# Add the response to the messages to maintain the history
messages.append(
messages += [
{'role': 'user', 'content': user_input},
{'role': 'assistant', 'content': response.message.content},
)
]
print(response.message.content + '\n')
+4 -4
View File
@@ -263,7 +263,7 @@ class Client(BaseClient):
model: str = '',
messages: Optional[Sequence[Union[Mapping[str, Any], Message]]] = None,
*,
tools: Optional[Sequence[Union[Mapping[str, Any], Tool]]] = None,
tools: Optional[Sequence[Union[Mapping[str, Any], Tool, Callable]]] = None,
stream: Literal[False] = False,
format: Optional[Literal['', 'json']] = None,
options: Optional[Union[Mapping[str, Any], Options]] = None,
@@ -276,7 +276,7 @@ class Client(BaseClient):
model: str = '',
messages: Optional[Sequence[Union[Mapping[str, Any], Message]]] = None,
*,
tools: Optional[Sequence[Union[Mapping[str, Any], Tool]]] = None,
tools: Optional[Sequence[Union[Mapping[str, Any], Tool, Callable]]] = None,
stream: Literal[True] = True,
format: Optional[Literal['', 'json']] = None,
options: Optional[Union[Mapping[str, Any], Options]] = None,
@@ -765,7 +765,7 @@ class AsyncClient(BaseClient):
model: str = '',
messages: Optional[Sequence[Union[Mapping[str, Any], Message]]] = None,
*,
tools: Optional[Sequence[Union[Mapping[str, Any], Tool]]] = None,
tools: Optional[Sequence[Union[Mapping[str, Any], Tool, Callable]]] = None,
stream: Literal[False] = False,
format: Optional[Literal['', 'json']] = None,
options: Optional[Union[Mapping[str, Any], Options]] = None,
@@ -790,7 +790,7 @@ class AsyncClient(BaseClient):
model: str = '',
messages: Optional[Sequence[Union[Mapping[str, Any], Message]]] = None,
*,
tools: Optional[Sequence[Union[Mapping[str, Any], Tool]]] = None,
tools: Optional[Sequence[Union[Mapping[str, Any], Tool, Callable]]] = None,
stream: bool = False,
format: Optional[Literal['', 'json']] = None,
options: Optional[Union[Mapping[str, Any], Options]] = None,
+45 -4
View File
@@ -17,9 +17,32 @@ from pydantic import (
class SubscriptableBaseModel(BaseModel):
def __getitem__(self, key: str) -> Any:
return getattr(self, key)
"""
>>> msg = Message(role='user')
>>> msg['role']
'user'
>>> msg = Message(role='user')
>>> msg['nonexistent']
Traceback (most recent call last):
KeyError: 'nonexistent'
"""
if key in self:
return getattr(self, key)
raise KeyError(key)
def __setitem__(self, key: str, value: Any) -> None:
"""
>>> msg = Message(role='user')
>>> msg['role'] = 'assistant'
>>> msg['role']
'assistant'
>>> tool_call = Message.ToolCall(function=Message.ToolCall.Function(name='foo', arguments={}))
>>> msg = Message(role='user', content='hello')
>>> msg['tool_calls'] = [tool_call]
>>> msg['tool_calls'][0]['function']['name']
'foo'
"""
setattr(self, key, value)
def __contains__(self, key: str) -> bool:
@@ -61,7 +84,20 @@ class SubscriptableBaseModel(BaseModel):
return False
def get(self, key: str, default: Any = None) -> Any:
return getattr(self, key, default)
"""
>>> msg = Message(role='user')
>>> msg.get('role')
'user'
>>> msg = Message(role='user')
>>> msg.get('nonexistent')
>>> msg = Message(role='user')
>>> msg.get('nonexistent', 'default')
'default'
>>> msg = Message(role='user', tool_calls=[ Message.ToolCall(function=Message.ToolCall.Function(name='foo', arguments={}))])
>>> msg.get('tool_calls')[0]['function']['name']
'foo'
"""
return self[key] if key in self else default
class Options(SubscriptableBaseModel):
@@ -130,9 +166,14 @@ class Image(BaseModel):
return b64encode(self.value.read_bytes() if isinstance(self.value, Path) else self.value).decode()
if isinstance(self.value, str):
if Path(self.value).exists():
return b64encode(Path(self.value).read_bytes()).decode()
try:
if Path(self.value).exists():
return b64encode(Path(self.value).read_bytes()).decode()
except Exception:
# Long base64 string can't be wrapped in Path, so try to treat as base64 string
pass
# String might be a file path, but might not exist
if self.value.split('.')[-1] in ('png', 'jpg', 'jpeg', 'webp'):
raise ValueError(f'File {self.value} does not exist')
+6
View File
@@ -19,6 +19,12 @@ def test_image_serialization_base64_string():
assert img.model_dump() == b64_str # Should return as-is if valid base64
def test_image_serialization_long_base64_string():
b64_str = 'dGVzdCBiYXNlNjQgc3RyaW5n' * 1000
img = Image(value=b64_str)
assert img.model_dump() == b64_str # Should return as-is if valid base64
def test_image_serialization_plain_string():
img = Image(value='not a path or base64')
assert img.model_dump() == 'not a path or base64' # Should return as-is