diff --git a/.semversioner/next-release/patch-20250522003454958473.json b/.semversioner/next-release/patch-20250522003454958473.json
new file mode 100644
index 00000000..7a7feaf0
--- /dev/null
+++ b/.semversioner/next-release/patch-20250522003454958473.json
@@ -0,0 +1,4 @@
+{
+    "type": "patch",
+    "description": "Add full llm response to LLM Provider output"
+}
diff --git a/graphrag/language_model/providers/fnllm/models.py b/graphrag/language_model/providers/fnllm/models.py
index fda91c96..059b0412 100644
--- a/graphrag/language_model/providers/fnllm/models.py
+++ b/graphrag/language_model/providers/fnllm/models.py
@@ -83,7 +83,10 @@ class OpenAIChatFNLLM:
         else:
             response = await self.model(prompt, history=history, **kwargs)
         return BaseModelResponse(
-            output=BaseModelOutput(content=response.output.content),
+            output=BaseModelOutput(
+                content=response.output.content,
+                full_response=response.output.raw_model.to_dict(),
+            ),
             parsed_response=response.parsed_json,
             history=response.history,
             cache_hit=response.cache_hit,
@@ -282,7 +285,10 @@ class AzureOpenAIChatFNLLM:
         else:
             response = await self.model(prompt, history=history, **kwargs)
         return BaseModelResponse(
-            output=BaseModelOutput(content=response.output.content),
+            output=BaseModelOutput(
+                content=response.output.content,
+                full_response=response.output.raw_model.to_dict(),
+            ),
             parsed_response=response.parsed_json,
             history=response.history,
             cache_hit=response.cache_hit,
diff --git a/graphrag/language_model/response/base.py b/graphrag/language_model/response/base.py
index abc17dda..178259c4 100644
--- a/graphrag/language_model/response/base.py
+++ b/graphrag/language_model/response/base.py
@@ -18,6 +18,11 @@ class ModelOutput(Protocol):
         """Return the textual content of the output."""
         ...

+    @property
+    def full_response(self) -> dict[str, Any] | None:
+        """Return the complete JSON response returned by the model."""
+        ...
+

 class ModelResponse(Protocol, Generic[T]):
     """Protocol for LLM response."""
@@ -43,6 +48,10 @@ class BaseModelOutput(BaseModel):

     content: str = Field(..., description="The textual content of the output.")
     """The textual content of the output."""
+    full_response: dict[str, Any] | None = Field(
+        None, description="The complete JSON response returned by the LLM provider."
+    )
+    """The complete JSON response returned by the LLM provider."""


 class BaseModelResponse(BaseModel, Generic[T]):
diff --git a/graphrag/language_model/response/base.pyi b/graphrag/language_model/response/base.pyi
new file mode 100644
index 00000000..7a33b0a3
--- /dev/null
+++ b/graphrag/language_model/response/base.pyi
@@ -0,0 +1,50 @@
+# Copyright (c) 2025 Microsoft Corporation.
+# Licensed under the MIT License
+
+from typing import Any, Generic, Protocol, TypeVar
+
+from pydantic import BaseModel
+
+_T = TypeVar("_T", bound=BaseModel, covariant=True)
+
+class ModelOutput(Protocol):
+    @property
+    def content(self) -> str: ...
+    @property
+    def full_response(self) -> dict[str, Any] | None: ...
+
+class ModelResponse(Protocol, Generic[_T]):
+    @property
+    def output(self) -> ModelOutput: ...
+    @property
+    def parsed_response(self) -> _T | None: ...
+    @property
+    def history(self) -> list[Any]: ...
+
+class BaseModelOutput(BaseModel):
+    content: str
+    full_response: dict[str, Any] | None
+
+    def __init__(
+        self,
+        content: str,
+        full_response: dict[str, Any] | None = None,
+    ) -> None: ...
+
+class BaseModelResponse(BaseModel, Generic[_T]):
+    output: BaseModelOutput
+    parsed_response: _T | None
+    history: list[Any]
+    tool_calls: list[Any]
+    metrics: Any | None
+    cache_hit: bool | None
+
+    def __init__(
+        self,
+        output: BaseModelOutput,
+        parsed_response: _T | None = None,
+        history: list[Any] = ...,  # default provided by Pydantic
+        tool_calls: list[Any] = ...,  # default provided by Pydantic
+        metrics: Any | None = None,
+        cache_hit: bool | None = None,
+    ) -> None: ...
diff --git a/tests/integration/language_model/test_factory.py b/tests/integration/language_model/test_factory.py
index e25e4e24..af503265 100644
--- a/tests/integration/language_model/test_factory.py
+++ b/tests/integration/language_model/test_factory.py
@@ -33,7 +33,11 @@ async def test_create_custom_chat_model():
         def chat(
             self, prompt: str, history: list | None = None, **kwargs: Any
         ) -> ModelResponse:
-            return BaseModelResponse(output=BaseModelOutput(content="content"))
+            return BaseModelResponse(
+                output=BaseModelOutput(
+                    content="content", full_response={"key": "value"}
+                )
+            )

         async def achat_stream(
             self, prompt: str, history: list | None = None, **kwargs: Any
@@ -49,9 +53,11 @@ async def test_create_custom_chat_model():
     assert isinstance(model, CustomChatModel)
     response = await model.achat("prompt")
     assert response.output.content == "content"
+    assert response.output.full_response is None

     response = model.chat("prompt")
     assert response.output.content == "content"
+    assert response.output.full_response == {"key": "value"}


 async def test_create_custom_embedding_llm():
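For reviewers, a minimal usage sketch (not part of this change) of how a caller might read the new `full_response` field. It constructs the response objects directly; the `id`, `usage`, and `total_tokens` keys are assumptions about a typical OpenAI payload, not anything this PR guarantees:

```python
# Minimal sketch: inspecting the new full_response field on a model response.
# The payload contents below are hypothetical; in practice the value comes from
# the provider's raw response (response.output.raw_model.to_dict() in this PR).
from graphrag.language_model.response.base import BaseModelOutput, BaseModelResponse

response = BaseModelResponse(
    output=BaseModelOutput(
        content="Hello!",
        full_response={"id": "chatcmpl-123", "usage": {"total_tokens": 12}},
    )
)

print(response.output.content)  # textual content, unchanged behavior
raw = response.output.full_response  # new: full provider payload, or None
if raw is not None:
    print(raw.get("usage", {}).get("total_tokens"))  # e.g. token accounting
```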