Fix/llm bugs empty extraction (#1533)

* Add llm singleton and check for empty extraction

* Semver

* Tests and spellcheck

* Move the singletons to a proper place

* Leftover print

* Ruff
Alonso Guevara 2024-12-18 17:07:29 -06:00 committed by GitHub
parent f7cd155dbc
commit cfe2082669
4 changed files with 79 additions and 2 deletions


@@ -0,0 +1,4 @@
+{
+    "type": "patch",
+    "description": "Manage llm instances inside a cached singleton. Check for empty dfs after entity/relationship extraction"
+}


@@ -52,6 +52,18 @@ async def extract_graph(
         num_threads=extraction_num_threads,
     )
 
+    if not _validate_data(entity_dfs):
+        error_msg = "Entity Extraction failed. No entities detected during extraction."
+        callbacks.error(error_msg)
+        raise ValueError(error_msg)
+
+    if not _validate_data(relationship_dfs):
+        error_msg = (
+            "Entity Extraction failed. No relationships detected during extraction."
+        )
+        callbacks.error(error_msg)
+        raise ValueError(error_msg)
+
     merged_entities = _merge_entities(entity_dfs)
     merged_relationships = _merge_relationships(relationship_dfs)
@@ -145,3 +157,10 @@ def _compute_degree(graph: nx.Graph) -> pd.DataFrame:
         {"name": node, "degree": int(degree)}
         for node, degree in graph.degree  # type: ignore
     ])
+
+
+def _validate_data(df_list: list[pd.DataFrame]) -> bool:
+    """Validate that the dataframe list is valid. At least one dataframe must contain data."""
+    return any(
+        len(df) > 0 for df in df_list
+    )  # Check for len, not .empty, as the dfs have schemas in some cases
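A note on the len() check above: pandas defines DataFrame.empty as True whenever any axis has length 0, so it can disagree with len(df) > 0 on a frame that has index rows but no columns. A minimal standalone pandas illustration of the difference (not part of the commit):

import pandas as pd

# A frame with a schema (columns) but no rows: both checks agree it is empty.
schema_only = pd.DataFrame(columns=["title", "source_id"])
assert len(schema_only) == 0
assert schema_only.empty

# A frame with index rows but no columns: the two checks disagree.
rows_only = pd.DataFrame(index=[0, 1])
assert len(rows_only) > 0
assert rows_only.empty  # True because the columns axis has length 0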


@@ -24,6 +24,7 @@ from pydantic import TypeAdapter
 import graphrag.config.defaults as defs
 from graphrag.config.enums import LLMType
 from graphrag.config.models.llm_parameters import LLMParameters
+from graphrag.index.llm.manager import ChatLLMSingleton, EmbeddingsLLMSingleton
 
 from .mock_llm import MockChatLLM
@@ -110,6 +111,10 @@ def load_llm(
     chat_only=False,
 ) -> ChatLLM:
     """Load the LLM for the entity extraction chain."""
+    singleton_llm = ChatLLMSingleton().get_llm(name)
+    if singleton_llm is not None:
+        return singleton_llm
+
     on_error = _create_error_handler(callbacks)
     llm_type = config.type
@@ -119,7 +124,9 @@ def load_llm(
             raise ValueError(msg)
 
         loader = loaders[llm_type]
-        return loader["load"](on_error, create_cache(cache, name), config)
+        llm_instance = loader["load"](on_error, create_cache(cache, name), config)
+        ChatLLMSingleton().set_llm(name, llm_instance)
+        return llm_instance
 
     msg = f"Unknown LLM type {llm_type}"
     raise ValueError(msg)
@@ -134,15 +141,21 @@ def load_llm_embeddings(
     chat_only=False,
 ) -> EmbeddingsLLM:
     """Load the LLM for the entity extraction chain."""
+    singleton_llm = EmbeddingsLLMSingleton().get_llm(name)
+    if singleton_llm is not None:
+        return singleton_llm
+
     on_error = _create_error_handler(callbacks)
     llm_type = llm_config.type
     if llm_type in loaders:
         if chat_only and not loaders[llm_type]["chat"]:
             msg = f"LLM type {llm_type} does not support chat"
             raise ValueError(msg)
-        return loaders[llm_type]["load"](
+        llm_instance = loaders[llm_type]["load"](
             on_error, create_cache(cache, name), llm_config or {}
         )
+        EmbeddingsLLMSingleton().set_llm(name, llm_instance)
+        return llm_instance
 
     msg = f"Unknown LLM type {llm_type}"
     raise ValueError(msg)
@@ -198,6 +211,7 @@ def _create_openai_config(config: LLMParameters, azure: bool) -> OpenAIConfig:
         n=config.n,
         temperature=config.temperature,
     )
+
     if azure:
         if config.api_base is None:
             msg = "Azure OpenAI Chat LLM requires an API base"


@@ -0,0 +1,40 @@
+# Copyright (c) 2024 Microsoft Corporation.
+# Licensed under the MIT License
+
+"""LLM Manager singleton."""
+
+from functools import cache
+
+from fnllm import ChatLLM, EmbeddingsLLM
+
+
+@cache
+class ChatLLMSingleton:
+    """A singleton class for the chat LLM instances."""
+
+    def __init__(self):
+        self.llm_dict = {}
+
+    def set_llm(self, name, llm):
+        """Add an LLM to the dictionary."""
+        self.llm_dict[name] = llm
+
+    def get_llm(self, name) -> ChatLLM | None:
+        """Get an LLM from the dictionary."""
+        return self.llm_dict.get(name)
+
+
+@cache
+class EmbeddingsLLMSingleton:
+    """A singleton class for the embeddings LLM instances."""
+
+    def __init__(self):
+        self.llm_dict = {}
+
+    def set_llm(self, name, llm):
+        """Add an LLM to the dictionary."""
+        self.llm_dict[name] = llm
+
+    def get_llm(self, name) -> EmbeddingsLLM | None:
+        """Get an LLM from the dictionary."""
+        return self.llm_dict.get(name)
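A note on the @cache decorator used on these classes: functools.cache memoizes the call itself, so every zero-argument instantiation like ChatLLMSingleton() returns the same memoized object; that is what makes these registries behave as singletons. A quick standalone demonstration of the idea (mirrors the pattern without importing graphrag or fnllm):

from functools import cache

@cache
class Registry:
    """Calling Registry() repeatedly returns one memoized instance."""

    def __init__(self):
        self.llm_dict = {}

    def set_llm(self, name, llm):
        self.llm_dict[name] = llm

    def get_llm(self, name):
        return self.llm_dict.get(name)

# State registered through one reference is visible through every other.
Registry().set_llm("chat", "fake-llm")
assert Registry() is Registry()
assert Registry().get_llm("chat") == "fake-llm"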