Task/raw model answer (#1947)

* Add full_response to llm provider output

* Semver

* Small leftover cleanup

* Add pyi to suppress Pyright errors. full_response is optional

* Format

* Add missing stubs
Alonso Guevara 2025-05-22 08:22:44 -06:00 committed by GitHub
parent fb4fe72a73
commit 7fba9522d4
5 changed files with 78 additions and 3 deletions

View File

@@ -0,0 +1,4 @@
+{
+    "type": "patch",
+    "description": "Add full llm response to LLM Provider output"
+}

View File

@@ -83,7 +83,10 @@ class OpenAIChatFNLLM:
         else:
             response = await self.model(prompt, history=history, **kwargs)
         return BaseModelResponse(
-            output=BaseModelOutput(content=response.output.content),
+            output=BaseModelOutput(
+                content=response.output.content,
+                full_response=response.output.raw_model.to_dict(),
+            ),
             parsed_response=response.parsed_json,
             history=response.history,
             cache_hit=response.cache_hit,
@@ -282,7 +285,10 @@ class AzureOpenAIChatFNLLM:
         else:
             response = await self.model(prompt, history=history, **kwargs)
         return BaseModelResponse(
-            output=BaseModelOutput(content=response.output.content),
+            output=BaseModelOutput(
+                content=response.output.content,
+                full_response=response.output.raw_model.to_dict(),
+            ),
             parsed_response=response.parsed_json,
             history=response.history,
             cache_hit=response.cache_hit,
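
Both providers now thread the raw fnllm payload through alongside the parsed content. Below is a minimal sketch of how a caller might consume the new field, assuming a configured chat model from one of the providers above; the helper name and the payload keys ("id", "usage") are illustrative assumptions, not part of this change:

# Hypothetical consumer of the new field; `model` is assumed to be one of
# the chat providers patched above.
async def log_raw_usage(model, prompt: str) -> None:
    response = await model.achat(prompt)
    raw = response.output.full_response  # dict[str, Any] | None
    if raw is not None:
        # The full provider payload carries details (e.g. token usage)
        # that the plain `content` string does not.
        print(raw.get("id"), raw.get("usage"))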

View File

@@ -18,6 +18,11 @@ class ModelOutput(Protocol):
         """Return the textual content of the output."""
         ...
+
+    @property
+    def full_response(self) -> dict[str, Any] | None:
+        """Return the complete JSON response returned by the model."""
+        ...


 class ModelResponse(Protocol, Generic[T]):
     """Protocol for LLM response."""
@@ -43,6 +48,10 @@ class BaseModelOutput(BaseModel):
     content: str = Field(..., description="The textual content of the output.")
     """The textual content of the output."""
+    full_response: dict[str, Any] | None = Field(
+        None, description="The complete JSON response returned by the LLM provider."
+    )
+    """The complete JSON response returned by the LLM provider."""


 class BaseModelResponse(BaseModel, Generic[T]):
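
Because `full_response` defaults to None, existing call sites that only pass `content` keep working unchanged. A quick sketch of both construction shapes (the import path is an assumption based on this repo's layout and may differ):

# Import path assumed; adjust to wherever BaseModelOutput lives in this repo.
from graphrag.language_model.response.base import BaseModelOutput

payload = {"id": "chatcmpl-123"}  # illustrative stand-in for a raw response
with_raw = BaseModelOutput(content="hello", full_response=payload)
legacy = BaseModelOutput(content="hello")  # full_response stays None
assert legacy.full_response is None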

View File

@@ -0,0 +1,50 @@
+# Copyright (c) 2025 Microsoft Corporation.
+# Licensed under the MIT License
+
+from typing import Any, Generic, Protocol, TypeVar
+
+from pydantic import BaseModel
+
+_T = TypeVar("_T", bound=BaseModel, covariant=True)
+
+class ModelOutput(Protocol):
+    @property
+    def content(self) -> str: ...
+    @property
+    def full_response(self) -> dict[str, Any] | None: ...
+
+class ModelResponse(Protocol, Generic[_T]):
+    @property
+    def output(self) -> ModelOutput: ...
+    @property
+    def parsed_response(self) -> _T | None: ...
+    @property
+    def history(self) -> list[Any]: ...
+
+class BaseModelOutput(BaseModel):
+    content: str
+    full_response: dict[str, Any] | None
+    def __init__(
+        self,
+        content: str,
+        full_response: dict[str, Any] | None = None,
+    ) -> None: ...
+
+class BaseModelResponse(BaseModel, Generic[_T]):
+    output: BaseModelOutput
+    parsed_response: _T | None
+    history: list[Any]
+    tool_calls: list[Any]
+    metrics: Any | None
+    cache_hit: bool | None
+    def __init__(
+        self,
+        output: BaseModelOutput,
+        parsed_response: _T | None = None,
+        history: list[Any] = ...,  # default provided by Pydantic
+        tool_calls: list[Any] = ...,  # default provided by Pydantic
+        metrics: Any | None = None,
+        cache_hit: bool | None = None,
+    ) -> None: ...
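
Per the commit message, this stub exists to suppress Pyright errors: Pydantic synthesizes `__init__` at runtime, so an explicit stub signature is what makes the optional `full_response` keyword and the Field-provided defaults visible to the type checker. A sketch of the call shapes the stub is meant to accept; the import path is an assumption:

# Both shapes should now type-check cleanly under Pyright.
from graphrag.language_model.response.base import (  # path assumed
    BaseModelOutput,
    BaseModelResponse,
)

minimal = BaseModelResponse(output=BaseModelOutput(content="hi"))
full = BaseModelResponse(
    output=BaseModelOutput(content="hi", full_response={"k": "v"})
)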

View File

@@ -33,7 +33,11 @@ async def test_create_custom_chat_model():
         def chat(
             self, prompt: str, history: list | None = None, **kwargs: Any
         ) -> ModelResponse:
-            return BaseModelResponse(output=BaseModelOutput(content="content"))
+            return BaseModelResponse(
+                output=BaseModelOutput(
+                    content="content", full_response={"key": "value"}
+                )
+            )

         async def achat_stream(
             self, prompt: str, history: list | None = None, **kwargs: Any
@@ -49,9 +53,11 @@ async def test_create_custom_chat_model():
     assert isinstance(model, CustomChatModel)

     response = await model.achat("prompt")
     assert response.output.content == "content"
+    assert response.output.full_response is None
     response = model.chat("prompt")
     assert response.output.content == "content"
+    assert response.output.full_response == {"key": "value"}


 async def test_create_custom_embedding_llm():