mirror of
https://github.com/ollama/ollama-python.git
synced 2026-01-13 21:57:16 +08:00
make subscription methods more consistent with maps
This commit is contained in:
parent
64e3723e6b
commit
2095fc9107
@ -17,9 +17,30 @@ from pydantic import (
|
||||
|
||||
class SubscriptableBaseModel(BaseModel):
|
||||
def __getitem__(self, key: str) -> Any:
|
||||
return getattr(self, key)
|
||||
"""
|
||||
>>> msg = Message(role='user')
|
||||
>>> msg['role']
|
||||
'user'
|
||||
>>> tool = Tool()
|
||||
>>> tool['type']
|
||||
'function'
|
||||
>>> msg = Message(role='user')
|
||||
>>> msg['nonexistent']
|
||||
Traceback (most recent call last):
|
||||
KeyError: 'nonexistent'
|
||||
"""
|
||||
if key in self:
|
||||
return getattr(self, key)
|
||||
|
||||
raise KeyError(key)
|
||||
|
||||
def __setitem__(self, key: str, value: Any) -> None:
|
||||
"""
|
||||
>>> msg = Message(role='user')
|
||||
>>> msg['role'] = 'assistant'
|
||||
>>> msg['role']
|
||||
'assistant'
|
||||
"""
|
||||
setattr(self, key, value)
|
||||
|
||||
def __contains__(self, key: str) -> bool:
|
||||
@ -61,7 +82,20 @@ class SubscriptableBaseModel(BaseModel):
|
||||
return False
|
||||
|
||||
def get(self, key: str, default: Any = None) -> Any:
|
||||
return getattr(self, key, default)
|
||||
"""
|
||||
>>> msg = Message(role='user')
|
||||
>>> msg.get('role')
|
||||
'user'
|
||||
>>> tool = Tool()
|
||||
>>> tool.get('type')
|
||||
'function'
|
||||
>>> msg = Message(role='user')
|
||||
>>> msg.get('nonexistent')
|
||||
>>> msg = Message(role='user')
|
||||
>>> msg.get('nonexistent', 'default')
|
||||
'default'
|
||||
"""
|
||||
return self[key] if key in self else default
|
||||
|
||||
|
||||
class Options(SubscriptableBaseModel):
|
||||
|
||||
Loading…
Reference in New Issue
Block a user