From 2095fc91076bb9b68acc48b5958c6d82b863ea1b Mon Sep 17 00:00:00 2001 From: jmorganca Date: Sat, 23 Nov 2024 19:02:28 -0800 Subject: [PATCH] make subscription methods more consistent with maps --- ollama/_types.py | 38 ++++++++++++++++++++++++++++++++++++-- 1 file changed, 36 insertions(+), 2 deletions(-) diff --git a/ollama/_types.py b/ollama/_types.py index 5be4850..2c9f3cb 100644 --- a/ollama/_types.py +++ b/ollama/_types.py @@ -17,9 +17,30 @@ from pydantic import ( class SubscriptableBaseModel(BaseModel): def __getitem__(self, key: str) -> Any: - return getattr(self, key) + """ + >>> msg = Message(role='user') + >>> msg['role'] + 'user' + >>> tool = Tool() + >>> tool['type'] + 'function' + >>> msg = Message(role='user') + >>> msg['nonexistent'] + Traceback (most recent call last): + KeyError: 'nonexistent' + """ + if key in self: + return getattr(self, key) + + raise KeyError(key) def __setitem__(self, key: str, value: Any) -> None: + """ + >>> msg = Message(role='user') + >>> msg['role'] = 'assistant' + >>> msg['role'] + 'assistant' + """ setattr(self, key, value) def __contains__(self, key: str) -> bool: @@ -61,7 +82,20 @@ class SubscriptableBaseModel(BaseModel): return False def get(self, key: str, default: Any = None) -> Any: - return getattr(self, key, default) + """ + >>> msg = Message(role='user') + >>> msg.get('role') + 'user' + >>> tool = Tool() + >>> tool.get('type') + 'function' + >>> msg = Message(role='user') + >>> msg.get('nonexistent') + >>> msg = Message(role='user') + >>> msg.get('nonexistent', 'default') + 'default' + """ + return self[key] if key in self else default class Options(SubscriptableBaseModel):