graphrag/tests/integration/language_model/test_factory.py
Nathan Evans 1bb9fa8e13
Unified factory (#2105)
* Simplify Factory interface

* Migrate CacheFactory to standard base class

* Migrate LoggerFactory to standard base class

* Migrate StorageFactory to standard base class

* Migrate VectorStoreFactory to standard base class

* Update vector store example notebook

* Delete notebook outputs

* Move default providers into factories

* Move retry/limit tests into integ

* Split language model factories

* Set smoke test tpm/rpm

* Fix factory integ tests

* Add method to smoke test, switch text to 'fast'

* Fix text smoke config for fast workflow

* Add new workflows to text smoke test

* Convert input readers to a proper factory

* Remove covariates from fast smoke test

* Update docs for input factory

* Bump smoke runtime

* Even longer runtime

* min-csv timeout

* Remove unnecessary lambdas
2025-10-20 12:05:27 -07:00
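
The commit bullets above describe consolidating graphrag's cache, logger, storage, vector store, and language model factories onto one standard base class. That base class is not shown on this page, so the following is only a minimal sketch assuming a simple register/create registry pattern; the `Factory` class, its method signatures, and the `create` helper are illustrative assumptions, not graphrag's actual API. The real test file, which exercises the concrete ChatModelFactory and EmbeddingModelFactory through the ModelManager, follows below.

# Hypothetical sketch of a "standard base class" factory (assumed, not graphrag's real code).
from collections.abc import Callable
from typing import Any, ClassVar


class Factory:
    """Base factory: each subclass keeps its own registry of named provider builders."""

    _registry: ClassVar[dict[str, Callable[..., Any]]]

    def __init_subclass__(cls, **kwargs: Any) -> None:
        super().__init_subclass__(**kwargs)
        cls._registry = {}  # give every concrete factory an independent registry

    def register(self, name: str, creator: Callable[..., Any]) -> None:
        """Map a provider name to the callable that constructs it."""
        type(self)._registry[name] = creator

    def create(self, name: str, **kwargs: Any) -> Any:
        """Build a previously registered provider by name."""
        try:
            return type(self)._registry[name](**kwargs)
        except KeyError as err:
            raise ValueError(f"Unknown provider: {name}") from err


class ChatModelFactory(Factory):
    """Chat-model flavor; the name mirrors the class the tests below register against."""

In graphrag itself, providers registered this way are resolved through the ModelManager, which is exactly what the test file below asserts.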


# Copyright (c) 2025 Microsoft Corporation.
# Licensed under the MIT License
"""LLMFactory Tests.
These tests will test the LLMFactory class and the creation of custom and provided LLMs.
"""

from collections.abc import AsyncGenerator, Generator
from typing import Any

from graphrag.language_model.factory import ChatModelFactory, EmbeddingModelFactory
from graphrag.language_model.manager import ModelManager
from graphrag.language_model.response.base import (
    BaseModelOutput,
    BaseModelResponse,
    ModelResponse,
)


async def test_create_custom_chat_model():
    # Minimal stand-in chat model implementing the methods the factory expects.
    class CustomChatModel:
        config: Any

        def __init__(self, **kwargs):
            pass

        async def achat(
            self, prompt: str, history: list | None = None, **kwargs: Any
        ) -> ModelResponse:
            return BaseModelResponse(output=BaseModelOutput(content="content"))

        def chat(
            self, prompt: str, history: list | None = None, **kwargs: Any
        ) -> ModelResponse:
            return BaseModelResponse(
                output=BaseModelOutput(
                    content="content", full_response={"key": "value"}
                )
            )

        async def achat_stream(
            self, prompt: str, history: list | None = None, **kwargs: Any
        ) -> AsyncGenerator[str, None]:
            yield ""

        def chat_stream(
            self, prompt: str, history: list | None = None, **kwargs: Any
        ) -> Generator[str, None]: ...

    # Register the custom implementation, then create it through the ModelManager.
    ChatModelFactory().register("custom_chat", CustomChatModel)
    model = ModelManager().get_or_create_chat_model("custom", "custom_chat")
    assert isinstance(model, CustomChatModel)

    # Both the async and sync chat paths should return the stubbed responses.
    response = await model.achat("prompt")
    assert response.output.content == "content"
    assert response.output.full_response is None

    response = model.chat("prompt")
    assert response.output.content == "content"
    assert response.output.full_response == {"key": "value"}


async def test_create_custom_embedding_llm():
    # Minimal stand-in embedding model returning fixed vectors.
    class CustomEmbeddingModel:
        config: Any

        def __init__(self, **kwargs):
            pass

        async def aembed(self, text: str, **kwargs) -> list[float]:
            return [1.0]

        def embed(self, text: str, **kwargs) -> list[float]:
            return [1.0]

        async def aembed_batch(
            self, text_list: list[str], **kwargs
        ) -> list[list[float]]:
            return [[1.0]]

        def embed_batch(self, text_list: list[str], **kwargs) -> list[list[float]]:
            return [[1.0]]

    # Register the custom implementation, then create it through the ModelManager.
    EmbeddingModelFactory().register("custom_embedding", CustomEmbeddingModel)
    llm = ModelManager().get_or_create_embedding_model("custom", "custom_embedding")
    assert isinstance(llm, CustomEmbeddingModel)

    # Single and batch embedding calls, async and sync, return the stub values.
    response = await llm.aembed("text")
    assert response == [1.0]

    response = llm.embed("text")
    assert response == [1.0]

    response = await llm.aembed_batch(["text"])
    assert response == [[1.0]]

    response = llm.embed_batch(["text"])
    assert response == [[1.0]]