Merge remote-tracking branch 'origin/main' into feat/metadata

Dayenne Souza 2025-01-30 17:12:14 -03:00
commit 3dcdd3ce53
18 changed files with 122 additions and 75 deletions

View File

@@ -0,0 +1,4 @@
+{
+    "type": "patch",
+    "description": "Require explicit azure auth settings when using AOI."
+}

View File

@@ -0,0 +1,4 @@
+{
+    "type": "patch",
+    "description": "Fix missing embeddings workflow in FastGraphRAG."
+}

View File

@@ -0,0 +1,4 @@
+{
+    "type": "patch",
+    "description": "Fix report generation recursion."
+}

View File

@@ -7,6 +7,7 @@ from pathlib import Path
 from graphrag.config.enums import (
     AsyncType,
+    AuthType,
     CacheType,
     ChunkStrategyType,
     InputFileType,
@@ -24,6 +25,7 @@ DEFAULT_EMBEDDING_MODEL_ID = "default_embedding_model"
 ASYNC_MODE = AsyncType.Threaded
 ENCODING_MODEL = "cl100k_base"
 AZURE_AUDIENCE = "https://cognitiveservices.azure.com/.default"
+AUTH_TYPE = AuthType.APIKey
 #
 # LLM Parameters
 #

View File

@@ -117,11 +117,11 @@ class LLMType(str, Enum):
         return f'"{self.value}"'
 
-class AzureAuthType(str, Enum):
-    """AzureAuthType enum class definition."""
+class AuthType(str, Enum):
+    """AuthType enum class definition."""
 
     APIKey = "api_key"
-    ManagedIdentity = "managed_identity"
+    AzureManagedIdentity = "azure_managed_identity"
 
 class AsyncType(str, Enum):

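Note: the hunk above interleaves the removed AzureAuthType lines with their replacements. Assembled from the added lines only, the resulting enum reads:

    from enum import Enum

    class AuthType(str, Enum):
        """AuthType enum class definition."""

        APIKey = "api_key"
        AzureManagedIdentity = "azure_managed_identity"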
View File

@@ -16,6 +16,7 @@ models:
   {defs.DEFAULT_CHAT_MODEL_ID}:
     api_key: ${{GRAPHRAG_API_KEY}} # set this in the generated .env file
     type: {defs.LLM_TYPE.value} # or azure_openai_chat
+    auth_type: {defs.AUTH_TYPE.value} # or azure_managed_identity
     model: {defs.LLM_MODEL}
     model_supports_json: true # recommended if this is available for your model.
     parallelization_num_threads: {defs.PARALLELIZATION_NUM_THREADS}
@@ -29,6 +30,7 @@ models:
   {defs.DEFAULT_EMBEDDING_MODEL_ID}:
     api_key: ${{GRAPHRAG_API_KEY}}
     type: {defs.EMBEDDING_TYPE.value} # or azure_openai_embedding
+    auth_type: {defs.AUTH_TYPE.value} # or azure_managed_identity
     model: {defs.EMBEDDING_MODEL}
     parallelization_num_threads: {defs.PARALLELIZATION_NUM_THREADS}
     parallelization_stagger: {defs.PARALLELIZATION_STAGGER}

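Note: the generated settings now carry an explicit auth_type for both models. A minimal sketch of the managed-identity configuration this enables, modeled on the test fixtures later in this commit (endpoint and model names are placeholders; a real Azure deployment may need further settings such as deployment_name):

    from graphrag.config.create_graphrag_config import create_graphrag_config

    # Placeholder values throughout; with azure_managed_identity, api_key is
    # intentionally omitted (providing one raises ConflictingSettingsError).
    config = create_graphrag_config({
        "models": {
            "default_chat_model": {
                "type": "azure_openai_chat",
                "auth_type": "azure_managed_identity",
                "model": "gpt-4o",
                "api_base": "https://example.openai.azure.com",
                "api_version": "2024-02-01",
            },
            "default_embedding_model": {
                "type": "azure_openai_embedding",
                "auth_type": "azure_managed_identity",
                "model": "text-embedding-3-small",
                "api_base": "https://example.openai.azure.com",
                "api_version": "2024-02-01",
            },
        },
    })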
View File

@@ -7,7 +7,7 @@ import tiktoken
 from pydantic import BaseModel, Field, model_validator
 
 import graphrag.config.defaults as defs
-from graphrag.config.enums import AsyncType, AzureAuthType, LLMType
+from graphrag.config.enums import AsyncType, AuthType, LLMType
 from graphrag.config.errors import (
     ApiKeyMissingError,
     AzureApiBaseMissingError,
@@ -40,27 +40,42 @@ class LanguageModelConfig(BaseModel):
         ApiKeyMissingError
             If the API key is missing and is required.
         """
-        if (
-            self.type == LLMType.OpenAIEmbedding
-            or self.type == LLMType.OpenAIChat
-            or self.azure_auth_type == AzureAuthType.APIKey
-        ) and (self.api_key is None or self.api_key.strip() == ""):
+        if self.auth_type == AuthType.APIKey and (
+            self.api_key is None or self.api_key.strip() == ""
+        ):
             raise ApiKeyMissingError(
                 self.type.value,
-                self.azure_auth_type.value if self.azure_auth_type else None,
+                self.auth_type.value,
             )
-        if (self.azure_auth_type == AzureAuthType.ManagedIdentity) and (
+        if (self.auth_type == AuthType.AzureManagedIdentity) and (
             self.api_key is not None and self.api_key.strip() != ""
         ):
             msg = "API Key should not be provided when using Azure Managed Identity. Please rerun `graphrag init` and remove the api_key when using Azure Managed Identity."
             raise ConflictingSettingsError(msg)
 
-    azure_auth_type: AzureAuthType | None = Field(
-        description="The Azure authentication type to use when using AOI.",
-        default=None,
+    auth_type: AuthType = Field(
+        description="The authentication type.",
+        default=defs.AUTH_TYPE,
     )
+
+    def _validate_auth_type(self) -> None:
+        """Validate the authentication type.
+
+        auth_type must be api_key when using OpenAI and
+        can be either api_key or azure_managed_identity when using AOI.
+
+        Raises
+        ------
+        ConflictingSettingsError
+            If the Azure authentication type conflicts with the model being used.
+        """
+        if self.auth_type == AuthType.AzureManagedIdentity and (
+            self.type == LLMType.OpenAIChat or self.type == LLMType.OpenAIEmbedding
+        ):
+            msg = f"auth_type of azure_managed_identity is not supported for model type {self.type.value}. Please rerun `graphrag init` and set the auth_type to api_key."
+            raise ConflictingSettingsError(msg)
+
     type: LLMType = Field(description="The type of LLM model to use.")
     model: str = Field(description="The LLM model to use.")
     encoding_model: str = Field(description="The encoding model to use", default="")
@@ -233,6 +248,7 @@ class LanguageModelConfig(BaseModel):
     @model_validator(mode="after")
     def _validate_model(self):
+        self._validate_auth_type()
         self._validate_api_key()
         self._validate_azure_settings()
         self._validate_encoding_model()

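Note: _validate_auth_type runs first in _validate_model, so a mismatched pairing fails before the API-key check. A condensed sketch of the behavior, adapted from test_conflicting_auth_type added later in this commit (the helper name here is ours):

    import pytest
    from pydantic import ValidationError

    import graphrag.config.defaults as defs
    from graphrag.config.create_graphrag_config import create_graphrag_config
    from tests.unit.config.utils import DEFAULT_EMBEDDING_MODEL_CONFIG

    def check_conflicting_auth_type() -> None:
        # azure_managed_identity is only valid for Azure OpenAI model types,
        # so pairing it with openai_chat must fail validation.
        models = {
            defs.DEFAULT_CHAT_MODEL_ID: {
                "auth_type": "azure_managed_identity",
                "type": "openai_chat",
                "model": defs.LLM_MODEL,
            },
            defs.DEFAULT_EMBEDDING_MODEL_ID: DEFAULT_EMBEDDING_MODEL_CONFIG,
        }
        with pytest.raises(ValidationError):
            create_graphrag_config({"models": models})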
View File

@@ -13,10 +13,9 @@ from graphrag.config import defaults
 from graphrag.config.enums import AsyncType
 from graphrag.index.operations.summarize_communities import (
     prepare_community_reports,
-    restore_community_hierarchy,
     summarize_communities,
 )
-from graphrag.index.operations.summarize_communities.community_reports_extractor import (
+from graphrag.index.operations.summarize_communities.community_reports_extractor.prep_community_report_context import (
     prep_community_report_context,
 )
 from graphrag.index.operations.summarize_communities.community_reports_extractor.schemas import (
@@ -39,9 +38,6 @@ from graphrag.index.operations.summarize_communities.community_reports_extractor
     NODE_ID,
     NODE_NAME,
 )
-from graphrag.index.operations.summarize_communities.community_reports_extractor.utils import (
-    get_levels,
-)
 
 
 async def create_final_community_reports(
@@ -66,35 +62,26 @@ async def create_final_community_reports(
     if claims_input is not None:
         claims = _prep_claims(claims_input)
 
+    max_input_length = summarization_strategy.get(
+        "max_input_length", defaults.COMMUNITY_REPORT_MAX_INPUT_LENGTH
+    )
+
     local_contexts = prepare_community_reports(
         nodes,
         edges,
         claims,
         callbacks,
-        summarization_strategy.get("max_input_length", 16_000),
+        max_input_length,
     )
 
-    community_hierarchy = restore_community_hierarchy(nodes)
-    levels = get_levels(nodes)
-    level_contexts = []
-    for level in levels:
-        level_context = prep_community_report_context(
-            local_context_df=local_contexts,
-            community_hierarchy_df=community_hierarchy,
-            level=level,
-            max_tokens=summarization_strategy.get(
-                "max_input_tokens", defaults.COMMUNITY_REPORT_MAX_INPUT_LENGTH
-            ),
-        )
-        level_contexts.append(level_context)
-
     community_reports = await summarize_communities(
+        nodes,
         local_contexts,
-        level_contexts,
+        prep_community_report_context,
         callbacks,
         cache,
         summarization_strategy,
+        max_input_length=max_input_length,
         async_mode=async_mode,
         num_threads=num_threads,
     )

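Note: with removals and additions interleaved, the refactored call site is hard to read at a glance. Assembled from the added and unchanged lines, it now looks like this (the dataframes, callbacks, and strategy dict come from the unchanged parts of the workflow):

    max_input_length = summarization_strategy.get(
        "max_input_length", defaults.COMMUNITY_REPORT_MAX_INPUT_LENGTH
    )
    local_contexts = prepare_community_reports(
        nodes, edges, claims, callbacks, max_input_length
    )
    community_reports = await summarize_communities(
        nodes,
        local_contexts,
        prep_community_report_context,  # level-context building is delegated
        callbacks,
        cache,
        summarization_strategy,
        max_input_length=max_input_length,
        async_mode=async_mode,
        num_threads=num_threads,
    )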
View File

@@ -13,12 +13,8 @@ from graphrag.callbacks.workflow_callbacks import WorkflowCallbacks
 from graphrag.config import defaults
 from graphrag.config.enums import AsyncType
 from graphrag.index.operations.summarize_communities import (
-    restore_community_hierarchy,
     summarize_communities,
 )
-from graphrag.index.operations.summarize_communities.community_reports_extractor.utils import (
-    get_levels,
-)
 from graphrag.index.operations.summarize_communities_text.context_builder import (
     prep_community_report_context,
     prep_local_context,
@@ -46,36 +42,25 @@ async def create_final_community_reports_text(
     nodes_df = nodes_input.merge(entities_df, on="id")
     nodes = nodes_df.loc[nodes_df.loc[:, "community"] != -1]
 
-    max_input_length = summarization_strategy.get("max_input_length", 16_000)
-
     # TEMP: forcing override of the prompt until we can put it into config
     summarization_strategy["extraction_prompt"] = COMMUNITY_REPORT_PROMPT
 
     # build initial local context for all communities
+    max_input_length = summarization_strategy.get(
+        "max_input_length", defaults.COMMUNITY_REPORT_MAX_INPUT_LENGTH
+    )
     local_contexts = prep_local_context(
         communities, text_units, nodes, max_input_length
     )
 
-    community_hierarchy = restore_community_hierarchy(nodes)
-    levels = get_levels(nodes)
-    level_contexts = []
-    for level in levels:
-        level_context = prep_community_report_context(
-            local_context_df=local_contexts,
-            community_hierarchy_df=community_hierarchy,
-            level=level,
-            max_tokens=summarization_strategy.get(
-                "max_input_tokens", defaults.COMMUNITY_REPORT_MAX_INPUT_LENGTH
-            ),
-        )
-        level_contexts.append(level_context)
-
     community_reports = await summarize_communities(
+        nodes,
        local_contexts,
-        level_contexts,
+        prep_community_report_context,
        callbacks,
        cache,
        summarization_strategy,
+        max_input_length=max_input_length,
        async_mode=async_mode,
        num_threads=num_threads,
    )

View File

@@ -30,6 +30,7 @@ log = logging.getLogger(__name__)
 def prep_community_report_context(
+    report_df: pd.DataFrame | None,
     community_hierarchy_df: pd.DataFrame,
     local_context_df: pd.DataFrame,
     level: int,
@@ -42,8 +43,6 @@
     - Check if local context fits within the limit, if yes use local context
     - If local context exceeds the limit, iteratively replace local context with sub-community reports, starting from the biggest sub-community
     """
-    report_df = pd.DataFrame()
-
     # Filter by community level
     level_context_df = local_context_df.loc[
         local_context_df.loc[:, schemas.COMMUNITY_LEVEL] == level
@@ -62,7 +61,7 @@
     if invalid_context_df.empty:
         return valid_context_df
 
-    if report_df.empty:
+    if report_df is None or report_df.empty:
         invalid_context_df.loc[:, schemas.CONTEXT_STRING] = _sort_and_trim_context(
             invalid_context_df, max_tokens
         )

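Note: report_df replaces a local variable that was unconditionally reset to an empty frame, so previously generated reports were never used when substituting sub-community context; this appears to be the "report generation recursion" fix noted in the changelog. A sketch of both call patterns, with community_hierarchy, local_contexts, and reports_so_far as hypothetical stand-ins for the caller's dataframes:

    # First pass: nothing generated yet, so oversized local context is just
    # sorted and trimmed to max_tokens.
    context = prep_community_report_context(
        report_df=None,
        community_hierarchy_df=community_hierarchy,
        local_context_df=local_contexts,
        level=level,
        max_tokens=16_000,
    )

    # Later passes: reports generated so far can substitute for oversized
    # sub-community context.
    context = prep_community_report_context(
        report_df=reports_so_far,
        community_hierarchy_df=community_hierarchy,
        local_context_df=local_contexts,
        level=level,
        max_tokens=16_000,
    )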
View File

@@ -4,6 +4,7 @@
 """A module containing create_community_reports and load_strategy methods definition."""
 
 import logging
+from collections.abc import Callable
 
 import pandas as pd
@@ -12,6 +13,12 @@ from graphrag.cache.pipeline_cache import PipelineCache
 from graphrag.callbacks.noop_workflow_callbacks import NoopWorkflowCallbacks
 from graphrag.callbacks.workflow_callbacks import WorkflowCallbacks
 from graphrag.config.enums import AsyncType
+from graphrag.index.operations.summarize_communities.community_reports_extractor.utils import (
+    get_levels,
+)
+from graphrag.index.operations.summarize_communities.restore_community_hierarchy import (
+    restore_community_hierarchy,
+)
 from graphrag.index.operations.summarize_communities.typing import (
     CommunityReport,
     CommunityReportsStrategy,
@@ -24,11 +31,13 @@ log = logging.getLogger(__name__)
 async def summarize_communities(
+    nodes: pd.DataFrame,
     local_contexts,
-    level_contexts,
+    level_context_builder: Callable,
     callbacks: WorkflowCallbacks,
     cache: PipelineCache,
     strategy: dict,
+    max_input_length: int,
     async_mode: AsyncType = AsyncType.AsyncIO,
     num_threads: int = 4,
 ):
@@ -37,6 +46,20 @@ async def summarize_communities(
     tick = progress_ticker(callbacks.progress, len(local_contexts))
     runner = load_strategy(strategy["type"])
+    community_hierarchy = restore_community_hierarchy(nodes)
+    levels = get_levels(nodes)
+
+    level_contexts = []
+    for level in levels:
+        level_context = level_context_builder(
+            pd.DataFrame(reports),
+            community_hierarchy_df=community_hierarchy,
+            local_context_df=local_contexts,
+            level=level,
+            max_tokens=max_input_length,
+        )
+        level_contexts.append(level_context)
+
     for level_context in level_contexts:
         async def run_generate(record):

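Note: summarize_communities now owns hierarchy restoration and the per-level loop, taking the context builder as a parameter instead of prebuilt level contexts. The builder contract implied by the call above, which both prep_community_report_context variants in this commit satisfy (first argument positional, the rest passed by keyword):

    import pandas as pd

    def level_context_builder(
        report_df: pd.DataFrame | None,
        community_hierarchy_df: pd.DataFrame,
        local_context_df: pd.DataFrame,
        level: int,
        max_tokens: int = 16_000,
    ) -> pd.DataFrame:
        """Return the report-generation context for one community level."""
        ...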
View File

@@ -76,10 +76,10 @@ def prep_local_context(
 def prep_community_report_context(
-    local_context_df: pd.DataFrame,
+    report_df: pd.DataFrame | None,
     community_hierarchy_df: pd.DataFrame,
+    local_context_df: pd.DataFrame,
     level: int,
-    report_df: pd.DataFrame | None = None,
     max_tokens: int = 16000,
 ) -> pd.DataFrame:
     """

View File

@@ -56,4 +56,5 @@ def _get_workflows_list(
         "create_final_communities",
         "create_final_text_units",
         "create_final_community_reports_text",
+        "generate_text_embeddings",
     ]

View File

@@ -5,7 +5,7 @@
 
 from azure.identity import DefaultAzureCredential, get_bearer_token_provider
 
-from graphrag.config.enums import LLMType
+from graphrag.config.enums import AuthType, LLMType
 from graphrag.config.models.graph_rag_config import GraphRagConfig
 from graphrag.query.llm.oai.chat_openai import ChatOpenAI
 from graphrag.query.llm.oai.embedding import OpenAIEmbedding
@@ -31,7 +31,8 @@ def get_llm(config: GraphRagConfig) -> ChatOpenAI:
         api_key=default_llm_settings.api_key,
         azure_ad_token_provider=(
             get_bearer_token_provider(DefaultAzureCredential(), audience)
-            if is_azure_client and not default_llm_settings.api_key
+            if is_azure_client
+            and default_llm_settings.auth_type == AuthType.AzureManagedIdentity
             else None
         ),
         api_base=default_llm_settings.api_base,
@@ -65,7 +66,8 @@ def get_text_embedder(config: GraphRagConfig) -> OpenAIEmbedding:
         api_key=embeddings_llm_settings.api_key,
         azure_ad_token_provider=(
             get_bearer_token_provider(DefaultAzureCredential(), audience)
-            if is_azure_client and not embeddings_llm_settings.api_key
+            if is_azure_client
+            and embeddings_llm_settings.auth_type == AuthType.AzureManagedIdentity
             else None
         ),
         api_base=embeddings_llm_settings.api_base,

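Note: the bearer-token provider is now selected from the declared auth_type rather than inferred from the absence of an api_key. Condensed into a sketch (make_token_provider is a hypothetical helper, not part of this change; DefaultAzureCredential and get_bearer_token_provider are the real azure-identity APIs used above):

    from azure.identity import DefaultAzureCredential, get_bearer_token_provider

    from graphrag.config.enums import AuthType

    def make_token_provider(settings, is_azure_client: bool, audience: str):
        # `settings` is a LanguageModelConfig; only Azure clients configured
        # for managed identity get a token provider, all others use api_key.
        if is_azure_client and settings.auth_type == AuthType.AzureManagedIdentity:
            return get_bearer_token_provider(DefaultAzureCredential(), audience)
        return None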
View File

@@ -1,5 +1,6 @@
 models:
   default_chat_model:
+    azure_auth_type: api_key
     type: ${GRAPHRAG_LLM_TYPE}
     api_key: ${GRAPHRAG_API_KEY}
     api_base: ${GRAPHRAG_API_BASE}
@@ -13,6 +14,7 @@ models:
     parallelization_stagger: 0.3
     async_mode: threaded
   default_embedding_model:
+    azure_auth_type: api_key
     type: ${GRAPHRAG_EMBEDDING_TYPE}
     api_key: ${GRAPHRAG_API_KEY}
     api_base: ${GRAPHRAG_API_BASE}

View File

@@ -1,5 +1,6 @@
 models:
   default_chat_model:
+    azure_auth_type: api_key
     type: ${GRAPHRAG_LLM_TYPE}
     api_key: ${GRAPHRAG_API_KEY}
     api_base: ${GRAPHRAG_API_BASE}
@@ -13,6 +14,7 @@ models:
     parallelization_stagger: 0.3
     async_mode: threaded
   default_embedding_model:
+    azure_auth_type: api_key
     type: ${GRAPHRAG_EMBEDDING_TYPE}
     api_key: ${GRAPHRAG_API_KEY}
     api_base: ${GRAPHRAG_API_BASE}

View File

@@ -10,7 +10,7 @@ from pydantic import ValidationError
 
 import graphrag.config.defaults as defs
 from graphrag.config.create_graphrag_config import create_graphrag_config
-from graphrag.config.enums import AzureAuthType, LLMType
+from graphrag.config.enums import AuthType, LLMType
 from graphrag.config.load_config import load_config
 from tests.unit.config.utils import (
     DEFAULT_EMBEDDING_MODEL_CONFIG,
@@ -46,7 +46,7 @@ def test_missing_azure_api_key() -> None:
     model_config_missing_api_key = {
         defs.DEFAULT_CHAT_MODEL_ID: {
             "type": LLMType.AzureOpenAIChat,
-            "azure_auth_type": AzureAuthType.APIKey,
+            "auth_type": AuthType.APIKey,
             "model": defs.LLM_MODEL,
             "api_base": "some_api_base",
             "api_version": "some_api_version",
@@ -59,17 +59,31 @@ def test_missing_azure_api_key() -> None:
         create_graphrag_config({"models": model_config_missing_api_key})
 
     # API Key not required for managed identity
-    model_config_missing_api_key[defs.DEFAULT_CHAT_MODEL_ID]["azure_auth_type"] = (
-        AzureAuthType.ManagedIdentity
+    model_config_missing_api_key[defs.DEFAULT_CHAT_MODEL_ID]["auth_type"] = (
+        AuthType.AzureManagedIdentity
     )
     create_graphrag_config({"models": model_config_missing_api_key})
 
 
+def test_conflicting_auth_type() -> None:
+    model_config_invalid_auth_type = {
+        defs.DEFAULT_CHAT_MODEL_ID: {
+            "auth_type": AuthType.AzureManagedIdentity,
+            "type": LLMType.OpenAIChat,
+            "model": defs.LLM_MODEL,
+        },
+        defs.DEFAULT_EMBEDDING_MODEL_ID: DEFAULT_EMBEDDING_MODEL_CONFIG,
+    }
+
+    with pytest.raises(ValidationError):
+        create_graphrag_config({"models": model_config_invalid_auth_type})
+
+
 def test_conflicting_azure_api_key() -> None:
     model_config_conflicting_api_key = {
         defs.DEFAULT_CHAT_MODEL_ID: {
             "type": LLMType.AzureOpenAIChat,
-            "azure_auth_type": AzureAuthType.ManagedIdentity,
+            "auth_type": AuthType.AzureManagedIdentity,
             "model": defs.LLM_MODEL,
             "api_base": "some_api_base",
             "api_version": "some_api_version",
@@ -85,7 +99,7 @@ def test_conflicting_azure_api_key() -> None:
 
 base_azure_model_config = {
     "type": LLMType.AzureOpenAIChat,
-    "azure_auth_type": AzureAuthType.ManagedIdentity,
+    "auth_type": AuthType.AzureManagedIdentity,
     "model": defs.LLM_MODEL,
     "api_base": "some_api_base",
     "api_version": "some_api_version",

View File

@@ -241,7 +241,7 @@ def assert_language_model_configs(
     actual: LanguageModelConfig, expected: LanguageModelConfig
 ) -> None:
     assert actual.api_key == expected.api_key
-    assert actual.azure_auth_type == expected.azure_auth_type
+    assert actual.auth_type == expected.auth_type
     assert actual.type == expected.type
     assert actual.model == expected.model
     assert actual.encoding_model == expected.encoding_model