Fix/llm bugs empty extraction (#1533)
* Add llm singleton and check for empty extraction
* Semver
* Tests and spellcheck
* Move the singletons to a proper place
* Leftover print
* Ruff
parent f7cd155dbc · commit cfe2082669
@@ -0,0 +1,4 @@
+{
+    "type": "patch",
+    "description": "Manage llm instances inside a cached singleton. Check for empty dfs after entity/relationship extraction"
+}
@@ -52,6 +52,18 @@ async def extract_graph(
         num_threads=extraction_num_threads,
     )
+
+    if not _validate_data(entity_dfs):
+        error_msg = "Entity Extraction failed. No entities detected during extraction."
+        callbacks.error(error_msg)
+        raise ValueError(error_msg)
+
+    if not _validate_data(relationship_dfs):
+        error_msg = (
+            "Entity Extraction failed. No relationships detected during extraction."
+        )
+        callbacks.error(error_msg)
+        raise ValueError(error_msg)

     merged_entities = _merge_entities(entity_dfs)
     merged_relationships = _merge_relationships(relationship_dfs)

@@ -145,3 +157,10 @@ def _compute_degree(graph: nx.Graph) -> pd.DataFrame:
         {"name": node, "degree": int(degree)}
         for node, degree in graph.degree  # type: ignore
     ])
+
+
+def _validate_data(df_list: list[pd.DataFrame]) -> bool:
+    """Validate that the dataframe list is valid. At least one dataframe must contain data."""
+    return any(
+        len(df) > 0 for df in df_list
+    )  # Check for len, not .empty, as the dfs have schemas in some cases
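As the inline comment notes, some extraction outputs are DataFrames that already carry a column schema, so the guard counts rows explicitly: "empty" here means "no rows extracted", not "no columns". A minimal sketch of the guard's behavior (the column names below are illustrative only, not the pipeline's actual schema):

import pandas as pd

def _validate_data(df_list: list[pd.DataFrame]) -> bool:
    """At least one dataframe must contain data (rows)."""
    return any(len(df) > 0 for df in df_list)

# A schema-only frame: columns defined, but zero rows extracted.
schema_only = pd.DataFrame(columns=["title", "type", "description"])
# A frame with one extracted row.
populated = pd.DataFrame([{"title": "A", "type": "ORG", "description": "x"}])

assert not _validate_data([schema_only])           # all empty -> extract_graph raises
assert _validate_data([schema_only, populated])    # any non-empty df passes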
@@ -24,6 +24,7 @@ from pydantic import TypeAdapter
 import graphrag.config.defaults as defs
 from graphrag.config.enums import LLMType
 from graphrag.config.models.llm_parameters import LLMParameters
+from graphrag.index.llm.manager import ChatLLMSingleton, EmbeddingsLLMSingleton

 from .mock_llm import MockChatLLM

@@ -110,6 +111,10 @@ def load_llm(
     chat_only=False,
 ) -> ChatLLM:
     """Load the LLM for the entity extraction chain."""
+    singleton_llm = ChatLLMSingleton().get_llm(name)
+    if singleton_llm is not None:
+        return singleton_llm
+
     on_error = _create_error_handler(callbacks)
     llm_type = config.type

@@ -119,7 +124,9 @@
             raise ValueError(msg)

         loader = loaders[llm_type]
-        return loader["load"](on_error, create_cache(cache, name), config)
+        llm_instance = loader["load"](on_error, create_cache(cache, name), config)
+        ChatLLMSingleton().set_llm(name, llm_instance)
+        return llm_instance

     msg = f"Unknown LLM type {llm_type}"
     raise ValueError(msg)
@@ -134,15 +141,21 @@ def load_llm_embeddings(
     chat_only=False,
 ) -> EmbeddingsLLM:
     """Load the LLM for the entity extraction chain."""
+    singleton_llm = EmbeddingsLLMSingleton().get_llm(name)
+    if singleton_llm is not None:
+        return singleton_llm
+
     on_error = _create_error_handler(callbacks)
     llm_type = llm_config.type
     if llm_type in loaders:
         if chat_only and not loaders[llm_type]["chat"]:
             msg = f"LLM type {llm_type} does not support chat"
             raise ValueError(msg)
-        return loaders[llm_type]["load"](
+        llm_instance = loaders[llm_type]["load"](
             on_error, create_cache(cache, name), llm_config or {}
         )
+        EmbeddingsLLMSingleton().set_llm(name, llm_instance)
+        return llm_instance

     msg = f"Unknown LLM type {llm_type}"
     raise ValueError(msg)
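Both loaders now share the same check-create-store shape: consult the singleton's registry first, build the client only on a miss, then record it under the same `name` so later calls reuse the instance. A stripped-down sketch of that flow; `make_client` and `registry` are hypothetical stand-ins for `loaders[llm_type]["load"]` and the singleton's dict, not part of the graphrag API:

registry: dict[str, object] = {}

def make_client(name: str) -> object:
    """Stand-in for the expensive loader["load"](...) construction."""
    return object()

def load_cached(name: str) -> object:
    cached = registry.get(name)    # 1. check the registry
    if cached is not None:
        return cached              # hit: reuse the existing instance
    instance = make_client(name)   # 2. miss: construct once
    registry[name] = instance      # 3. store for subsequent calls
    return instance

assert load_cached("default_chat") is load_cached("default_chat")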
@@ -198,6 +211,7 @@ def _create_openai_config(config: LLMParameters, azure: bool) -> OpenAIConfig:
         n=config.n,
         temperature=config.temperature,
     )
+
     if azure:
         if config.api_base is None:
             msg = "Azure OpenAI Chat LLM requires an API base"
graphrag/index/llm/manager.py (new file, 40 lines)
@@ -0,0 +1,40 @@
+# Copyright (c) 2024 Microsoft Corporation.
+# Licensed under the MIT License
+
+"""LLM Manager singleton."""
+
+from functools import cache
+
+from fnllm import ChatLLM, EmbeddingsLLM
+
+
+@cache
+class ChatLLMSingleton:
+    """A singleton class for the chat LLM instances."""
+
+    def __init__(self):
+        self.llm_dict = {}
+
+    def set_llm(self, name, llm):
+        """Add an LLM to the dictionary."""
+        self.llm_dict[name] = llm
+
+    def get_llm(self, name) -> ChatLLM | None:
+        """Get an LLM from the dictionary."""
+        return self.llm_dict.get(name)
+
+
+@cache
+class EmbeddingsLLMSingleton:
+    """A singleton class for the embeddings LLM instances."""
+
+    def __init__(self):
+        self.llm_dict = {}
+
+    def set_llm(self, name, llm):
+        """Add an LLM to the dictionary."""
+        self.llm_dict[name] = llm
+
+    def get_llm(self, name) -> EmbeddingsLLM | None:
+        """Get an LLM from the dictionary."""
+        return self.llm_dict.get(name)
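The singleton behavior in this new module comes from applying `functools.cache` to the class itself: a class is callable, so `cache` memoizes the zero-argument call, and every `ChatLLMSingleton()` evaluates to the same object and therefore the same `llm_dict`. A self-contained demonstration of the mechanic; the `Registry` class here is illustrative, not part of the module:

from functools import cache

@cache
class Registry:
    """Same @cache-on-a-class pattern as ChatLLMSingleton."""

    def __init__(self):
        self.llm_dict = {}

a = Registry()
b = Registry()
assert a is b                          # the call Registry() is memoized
a.llm_dict["chat"] = "client"
assert b.llm_dict["chat"] == "client"  # state is shared, as a singleton requires

One trade-off worth noting: after decoration the name is bound to a cached callable rather than a class, so `isinstance(x, ChatLLMSingleton)` checks no longer work; the pattern trades that away for brevity.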