From 07eec6d517b95ef49fda672d8a533cb01a56b2fa Mon Sep 17 00:00:00 2001 From: rylativity <41017744+rylativity@users.noreply.github.com> Date: Thu, 20 Mar 2025 16:46:40 -0400 Subject: [PATCH] types: enable passing messages with arbitrary role (#462) --------- Co-authored-by: Ryan Stewart Co-authored-by: Gabe Goodhart --- ollama/_types.py | 2 +- tests/test_client.py | 33 ++++++++++++++++++++++++++++++++- 2 files changed, 33 insertions(+), 2 deletions(-) diff --git a/ollama/_types.py b/ollama/_types.py index 710c536..0df0ddb 100644 --- a/ollama/_types.py +++ b/ollama/_types.py @@ -256,7 +256,7 @@ class Message(SubscriptableBaseModel): Chat message. """ - role: Literal['user', 'assistant', 'system', 'tool'] + role: str "Assumed role of the message. Response messages has role 'assistant' or 'tool'." content: Optional[str] = None diff --git a/tests/test_client.py b/tests/test_client.py index 8890afd..ca29806 100644 --- a/tests/test_client.py +++ b/tests/test_client.py @@ -4,14 +4,16 @@ import os import re import tempfile from pathlib import Path +from typing import Any import pytest +from httpx import Response as httpxResponse from pydantic import BaseModel, ValidationError from pytest_httpserver import HTTPServer, URIPattern from werkzeug.wrappers import Request, Response from ollama._client import CONNECTION_ERROR_MESSAGE, AsyncClient, Client, _copy_tools -from ollama._types import Image +from ollama._types import Image, Message PNG_BASE64 = 'iVBORw0KGgoAAAANSUhEUgAAAAEAAAABCAIAAACQd1PeAAAADElEQVR4nGNgYGAAAAAEAAH2FzhVAAAAAElFTkSuQmCC' PNG_BYTES = base64.b64decode(PNG_BASE64) @@ -1181,3 +1183,32 @@ async def test_async_client_connection_error(): with pytest.raises(ConnectionError) as exc_info: await client.show('model') assert str(exc_info.value) == 'Failed to connect to Ollama. Please check that Ollama is downloaded, running and accessible. https://ollama.com/download' + + +def test_arbitrary_roles_accepted_in_message(): + _ = Message(role='somerandomrole', content="I'm ok with you adding any role message now!") + + +def _mock_request(*args: Any, **kwargs: Any) -> Response: + return httpxResponse(status_code=200, content="{'response': 'Hello world!'}") + + +def test_arbitrary_roles_accepted_in_message_request(monkeypatch: pytest.MonkeyPatch): + monkeypatch.setattr(Client, '_request', _mock_request) + + client = Client() + + client.chat(model='llama3.1', messages=[{'role': 'somerandomrole', 'content': "I'm ok with you adding any role message now!"}, {'role': 'user', 'content': 'Hello world!'}]) + + +async def _mock_request_async(*args: Any, **kwargs: Any) -> Response: + return httpxResponse(status_code=200, content="{'response': 'Hello world!'}") + + +@pytest.mark.asyncio +async def test_arbitrary_roles_accepted_in_message_request_async(monkeypatch: pytest.MonkeyPatch): + monkeypatch.setattr(AsyncClient, '_request', _mock_request_async) + + client = AsyncClient() + + await client.chat(model='llama3.1', messages=[{'role': 'somerandomrole', 'content': "I'm ok with you adding any role message now!"}, {'role': 'user', 'content': 'Hello world!'}])