graphrag/tests/integration/language_model/test_factory.py
Nathan Evans ad4cdd685f
Support OpenAI reasoning models (#1841)
* Update tiktoken
* Add max_completion_tokens to model config
* Update/remove outdated comments
* Remove max_tokens from report generation
* Remove max_tokens from entity summarization
* Remove logit_bias from graph extraction
* Remove logit_bias from claim extraction
* Swap params if reasoning model
* Add reasoning model support to basic search
* Add reasoning model support for local and global search
* Support reasoning models with dynamic community selection
* Support reasoning models in DRIFT search
* Remove unused num_threads entry
* Semver
* Update openai
* Add reasoning_effort param
2025-04-22 14:15:26 -07:00
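
The parameter swap called out in this commit mirrors a constraint of the OpenAI API: reasoning models (the o-series) reject max_tokens and logit_bias, and expect max_completion_tokens plus an optional reasoning_effort instead. A minimal sketch of that kind of swap, assuming a hypothetical adapt_request_params helper and a simple model-name prefix check (not GraphRAG's actual implementation):

    # Hypothetical sketch: swap request parameters for OpenAI reasoning models.
    # The prefix tuple and helper name are illustrative, not GraphRAG's own code.
    REASONING_MODEL_PREFIXES = ("o1", "o3", "o4")


    def adapt_request_params(model: str, params: dict) -> dict:
        """Return a copy of params adjusted for a reasoning model, if needed."""
        adapted = dict(params)
        if model.startswith(REASONING_MODEL_PREFIXES):
            # Reasoning models take max_completion_tokens instead of max_tokens.
            if "max_tokens" in adapted:
                adapted["max_completion_tokens"] = adapted.pop("max_tokens")
            # logit_bias is not supported by reasoning models; drop it.
            adapted.pop("logit_bias", None)
        return adapted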


# Copyright (c) 2025 Microsoft Corporation.
# Licensed under the MIT License

"""ModelFactory Tests.

These tests cover the ModelFactory class and the creation of custom and provided LLMs.
"""

from collections.abc import AsyncGenerator, Generator
from typing import Any

from graphrag.language_model.factory import ModelFactory
from graphrag.language_model.manager import ModelManager
from graphrag.language_model.response.base import (
    BaseModelOutput,
    BaseModelResponse,
    ModelResponse,
)


async def test_create_custom_chat_model():
    class CustomChatModel:
        config: Any

        def __init__(self, **kwargs):
            pass

        async def achat(
            self, prompt: str, history: list | None = None, **kwargs: Any
        ) -> ModelResponse:
            return BaseModelResponse(output=BaseModelOutput(content="content"))

        def chat(
            self, prompt: str, history: list | None = None, **kwargs: Any
        ) -> ModelResponse:
            return BaseModelResponse(output=BaseModelOutput(content="content"))

        async def achat_stream(
            self, prompt: str, history: list | None = None, **kwargs: Any
        ) -> AsyncGenerator[str, None]:
            yield ""

        def chat_stream(
            self, prompt: str, history: list | None = None, **kwargs: Any
        ) -> Generator[str, None]: ...

    # Register the custom implementation, then resolve it through the manager.
    ModelFactory.register_chat("custom_chat", CustomChatModel)
    model = ModelManager().get_or_create_chat_model("custom", "custom_chat")
    assert isinstance(model, CustomChatModel)

    response = await model.achat("prompt")
    assert response.output.content == "content"
    response = model.chat("prompt")
    assert response.output.content == "content"


async def test_create_custom_embedding_llm():
    class CustomEmbeddingModel:
        config: Any

        def __init__(self, **kwargs):
            pass

        async def aembed(self, text: str, **kwargs) -> list[float]:
            return [1.0]

        def embed(self, text: str, **kwargs) -> list[float]:
            return [1.0]

        async def aembed_batch(
            self, text_list: list[str], **kwargs
        ) -> list[list[float]]:
            return [[1.0]]

        def embed_batch(self, text_list: list[str], **kwargs) -> list[list[float]]:
            return [[1.0]]

    # Register the custom implementation, then resolve it through the manager.
    ModelFactory.register_embedding("custom_embedding", CustomEmbeddingModel)
    llm = ModelManager().get_or_create_embedding_model("custom", "custom_embedding")
    assert isinstance(llm, CustomEmbeddingModel)

    response = await llm.aembed("text")
    assert response == [1.0]
    response = llm.embed("text")
    assert response == [1.0]
    response = await llm.aembed_batch(["text"])
    assert response == [[1.0]]
    response = llm.embed_batch(["text"])
    assert response == [[1.0]]
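
Beyond the assertions above, the same register-then-resolve pattern is how a custom provider would be plugged in outside the test suite. A minimal usage sketch reusing the CustomChatModel stub from the first test (the "my_provider" and "my_instance" names are illustrative):

    # Register the stub under a provider key, then fetch a named instance.
    ModelFactory.register_chat("my_provider", CustomChatModel)
    chat_model = ModelManager().get_or_create_chat_model("my_instance", "my_provider")
    # The stub always returns the canned response, so this prints "content".
    print(chat_model.chat("Hello").output.content)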