mirror of https://github.com/microsoft/graphrag.git
synced 2026-01-14 00:57:23 +08:00
Merge remote-tracking branch 'origin/main' into feat/metadata
commit 3dcdd3ce53
@@ -0,0 +1,4 @@
+{
+    "type": "patch",
+    "description": "Require explicit azure auth settings when using AOI."
+}

@@ -0,0 +1,4 @@
+{
+    "type": "patch",
+    "description": "Fix missing embeddings workflow in FastGraphRAG."
+}

@@ -0,0 +1,4 @@
+{
+    "type": "patch",
+    "description": "Fix report generation recursion."
+}
@@ -7,6 +7,7 @@ from pathlib import Path
 
 from graphrag.config.enums import (
     AsyncType,
+    AuthType,
     CacheType,
     ChunkStrategyType,
     InputFileType,

@@ -24,6 +25,7 @@ DEFAULT_EMBEDDING_MODEL_ID = "default_embedding_model"
 ASYNC_MODE = AsyncType.Threaded
 ENCODING_MODEL = "cl100k_base"
 AZURE_AUDIENCE = "https://cognitiveservices.azure.com/.default"
+AUTH_TYPE = AuthType.APIKey
 #
 # LLM Parameters
 #
@@ -117,11 +117,11 @@ class LLMType(str, Enum):
         return f'"{self.value}"'
 
 
-class AzureAuthType(str, Enum):
-    """AzureAuthType enum class definition."""
+class AuthType(str, Enum):
+    """AuthType enum class definition."""
 
     APIKey = "api_key"
-    ManagedIdentity = "managed_identity"
+    AzureManagedIdentity = "azure_managed_identity"
 
 
 class AsyncType(str, Enum):
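
Note: the rename drops the Azure prefix from the enum and makes the managed-identity member name explicit. Since the enum is string-valued, settings files round-trip directly; a minimal sketch, assuming graphrag.config.enums from this commit:

    from graphrag.config.enums import AuthType

    # String-valued members map 1:1 onto the auth_type settings field.
    assert AuthType("api_key") is AuthType.APIKey
    assert AuthType("azure_managed_identity") is AuthType.AzureManagedIdentity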
@@ -16,6 +16,7 @@ models:
   {defs.DEFAULT_CHAT_MODEL_ID}:
     api_key: ${{GRAPHRAG_API_KEY}} # set this in the generated .env file
     type: {defs.LLM_TYPE.value} # or azure_openai_chat
+    auth_type: {defs.AUTH_TYPE.value} # or azure_managed_identity
     model: {defs.LLM_MODEL}
     model_supports_json: true # recommended if this is available for your model.
     parallelization_num_threads: {defs.PARALLELIZATION_NUM_THREADS}

@@ -29,6 +30,7 @@ models:
   {defs.DEFAULT_EMBEDDING_MODEL_ID}:
     api_key: ${{GRAPHRAG_API_KEY}}
     type: {defs.EMBEDDING_TYPE.value} # or azure_openai_embedding
+    auth_type: {defs.AUTH_TYPE.value} # or azure_managed_identity
     model: {defs.EMBEDDING_MODEL}
     parallelization_num_threads: {defs.PARALLELIZATION_NUM_THREADS}
     parallelization_stagger: {defs.PARALLELIZATION_STAGGER}
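
Note: the generated settings now carry an explicit auth_type alongside type. A hedged sketch of the programmatic equivalent, using the dict-based create_graphrag_config API that the unit tests later in this diff exercise; model IDs and model names come from graphrag defaults, and the api_key values are placeholders:

    import graphrag.config.defaults as defs
    from graphrag.config.create_graphrag_config import create_graphrag_config
    from graphrag.config.enums import AuthType, LLMType

    config = create_graphrag_config({
        "models": {
            defs.DEFAULT_CHAT_MODEL_ID: {
                "type": LLMType.OpenAIChat,    # or LLMType.AzureOpenAIChat
                "auth_type": AuthType.APIKey,  # azure_managed_identity is AOI-only
                "api_key": "<your-key>",       # required when auth_type is api_key
                "model": defs.LLM_MODEL,
            },
            defs.DEFAULT_EMBEDDING_MODEL_ID: {
                "type": LLMType.OpenAIEmbedding,
                "auth_type": AuthType.APIKey,
                "api_key": "<your-key>",
                "model": defs.EMBEDDING_MODEL,
            },
        }
    })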
@@ -7,7 +7,7 @@ import tiktoken
 from pydantic import BaseModel, Field, model_validator
 
 import graphrag.config.defaults as defs
-from graphrag.config.enums import AsyncType, AzureAuthType, LLMType
+from graphrag.config.enums import AsyncType, AuthType, LLMType
 from graphrag.config.errors import (
     ApiKeyMissingError,
     AzureApiBaseMissingError,
@@ -40,27 +40,42 @@ class LanguageModelConfig(BaseModel):
         ApiKeyMissingError
             If the API key is missing and is required.
         """
-        if (
-            self.type == LLMType.OpenAIEmbedding
-            or self.type == LLMType.OpenAIChat
-            or self.azure_auth_type == AzureAuthType.APIKey
-        ) and (self.api_key is None or self.api_key.strip() == ""):
+        if self.auth_type == AuthType.APIKey and (
+            self.api_key is None or self.api_key.strip() == ""
+        ):
             raise ApiKeyMissingError(
                 self.type.value,
-                self.azure_auth_type.value if self.azure_auth_type else None,
+                self.auth_type.value,
             )
 
-        if (self.azure_auth_type == AzureAuthType.ManagedIdentity) and (
+        if (self.auth_type == AuthType.AzureManagedIdentity) and (
             self.api_key is not None and self.api_key.strip() != ""
         ):
             msg = "API Key should not be provided when using Azure Managed Identity. Please rerun `graphrag init` and remove the api_key when using Azure Managed Identity."
             raise ConflictingSettingsError(msg)
 
-    azure_auth_type: AzureAuthType | None = Field(
-        description="The Azure authentication type to use when using AOI.",
-        default=None,
+    auth_type: AuthType = Field(
+        description="The authentication type.",
+        default=defs.AUTH_TYPE,
     )
 
+    def _validate_auth_type(self) -> None:
+        """Validate the authentication type.
+
+        auth_type must be api_key when using OpenAI and
+        can be either api_key or azure_managed_identity when using AOI.
+
+        Raises
+        ------
+        ConflictingSettingsError
+            If the Azure authentication type conflicts with the model being used.
+        """
+        if self.auth_type == AuthType.AzureManagedIdentity and (
+            self.type == LLMType.OpenAIChat or self.type == LLMType.OpenAIEmbedding
+        ):
+            msg = f"auth_type of azure_managed_identity is not supported for model type {self.type.value}. Please rerun `graphrag init` and set the auth_type to api_key."
+            raise ConflictingSettingsError(msg)
+
     type: LLMType = Field(description="The type of LLM model to use.")
     model: str = Field(description="The LLM model to use.")
     encoding_model: str = Field(description="The encoding model to use", default="")
@@ -233,6 +248,7 @@ class LanguageModelConfig(BaseModel):
 
     @model_validator(mode="after")
     def _validate_model(self):
+        self._validate_auth_type()
        self._validate_api_key()
        self._validate_azure_settings()
        self._validate_encoding_model()
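
Note: the net effect of the validator changes is that the API-key requirement is now keyed off auth_type alone (it previously also fired for the OpenAI chat/embedding model types), and _validate_auth_type runs first so a managed-identity setting on a non-Azure model type fails fast. A sketch of the failure mode; the module path for LanguageModelConfig is assumed from the file shown above, and the unit tests later in this diff rely on pydantic surfacing these errors as ValidationError:

    from pydantic import ValidationError

    from graphrag.config.enums import AuthType, LLMType
    from graphrag.config.models.language_model_config import LanguageModelConfig

    try:
        # api_key auth with no key now fails for every model type, AOI included.
        LanguageModelConfig(
            type=LLMType.AzureOpenAIChat,
            auth_type=AuthType.APIKey,
            model="gpt-4",
            api_base="https://example.openai.azure.com",
            api_version="2024-02-01",
        )
    except ValidationError as err:
        print(err)  # wraps the ApiKeyMissingError raised in _validate_api_key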
@@ -13,10 +13,9 @@ from graphrag.config import defaults
 from graphrag.config.enums import AsyncType
 from graphrag.index.operations.summarize_communities import (
     prepare_community_reports,
-    restore_community_hierarchy,
     summarize_communities,
 )
-from graphrag.index.operations.summarize_communities.community_reports_extractor import (
+from graphrag.index.operations.summarize_communities.community_reports_extractor.prep_community_report_context import (
     prep_community_report_context,
 )
 from graphrag.index.operations.summarize_communities.community_reports_extractor.schemas import (

@@ -39,9 +38,6 @@ from graphrag.index.operations.summarize_communities.community_reports_extractor
     NODE_ID,
     NODE_NAME,
 )
-from graphrag.index.operations.summarize_communities.community_reports_extractor.utils import (
-    get_levels,
-)
 
 
 async def create_final_community_reports(
@@ -66,35 +62,26 @@ async def create_final_community_reports(
     if claims_input is not None:
         claims = _prep_claims(claims_input)
 
+    max_input_length = summarization_strategy.get(
+        "max_input_length", defaults.COMMUNITY_REPORT_MAX_INPUT_LENGTH
+    )
+
     local_contexts = prepare_community_reports(
         nodes,
         edges,
         claims,
         callbacks,
-        summarization_strategy.get("max_input_length", 16_000),
+        max_input_length,
     )
 
-    community_hierarchy = restore_community_hierarchy(nodes)
-    levels = get_levels(nodes)
-
-    level_contexts = []
-    for level in levels:
-        level_context = prep_community_report_context(
-            local_context_df=local_contexts,
-            community_hierarchy_df=community_hierarchy,
-            level=level,
-            max_tokens=summarization_strategy.get(
-                "max_input_tokens", defaults.COMMUNITY_REPORT_MAX_INPUT_LENGTH
-            ),
-        )
-        level_contexts.append(level_context)
-
     community_reports = await summarize_communities(
+        nodes,
         local_contexts,
-        level_contexts,
+        prep_community_report_context,
         callbacks,
         cache,
         summarization_strategy,
+        max_input_length=max_input_length,
         async_mode=async_mode,
         num_threads=num_threads,
     )
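
Note: both report workflows (this graph-based one and the text variant below) previously duplicated the restore-hierarchy/level-context loop; each now hands its own prep_community_report_context to summarize_communities as a callback. A hedged usage sketch of how a workflow drives the reworked operation; my_workflow and my_builder are illustrative stand-ins, and all other names mirror the diff:

    import pandas as pd

    from graphrag.cache.pipeline_cache import PipelineCache
    from graphrag.callbacks.workflow_callbacks import WorkflowCallbacks
    from graphrag.config.enums import AsyncType
    from graphrag.index.operations.summarize_communities import summarize_communities

    async def my_workflow(
        nodes: pd.DataFrame,
        local_contexts: pd.DataFrame,
        callbacks: WorkflowCallbacks,
        cache: PipelineCache,
        summarization_strategy: dict,
        my_builder,  # level_context_builder, e.g. prep_community_report_context
    ):
        # The operation now restores the hierarchy and builds per-level
        # contexts itself; the workflow only supplies the builder callback.
        return await summarize_communities(
            nodes,
            local_contexts,
            my_builder,
            callbacks,
            cache,
            summarization_strategy,
            max_input_length=summarization_strategy.get("max_input_length", 16_000),
            async_mode=AsyncType.AsyncIO,
            num_threads=4,
        )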
@@ -13,12 +13,8 @@ from graphrag.callbacks.workflow_callbacks import WorkflowCallbacks
 from graphrag.config import defaults
 from graphrag.config.enums import AsyncType
 from graphrag.index.operations.summarize_communities import (
-    restore_community_hierarchy,
     summarize_communities,
 )
-from graphrag.index.operations.summarize_communities.community_reports_extractor.utils import (
-    get_levels,
-)
 from graphrag.index.operations.summarize_communities_text.context_builder import (
     prep_community_report_context,
     prep_local_context,
@@ -46,36 +42,25 @@ async def create_final_community_reports_text(
     nodes_df = nodes_input.merge(entities_df, on="id")
     nodes = nodes_df.loc[nodes_df.loc[:, "community"] != -1]
 
-    max_input_length = summarization_strategy.get("max_input_length", 16_000)
     # TEMP: forcing override of the prompt until we can put it into config
     summarization_strategy["extraction_prompt"] = COMMUNITY_REPORT_PROMPT
-    # build initial local context for all communities
+
+    max_input_length = summarization_strategy.get(
+        "max_input_length", defaults.COMMUNITY_REPORT_MAX_INPUT_LENGTH
+    )
 
     local_contexts = prep_local_context(
         communities, text_units, nodes, max_input_length
     )
 
-    community_hierarchy = restore_community_hierarchy(nodes)
-    levels = get_levels(nodes)
-
-    level_contexts = []
-    for level in levels:
-        level_context = prep_community_report_context(
-            local_context_df=local_contexts,
-            community_hierarchy_df=community_hierarchy,
-            level=level,
-            max_tokens=summarization_strategy.get(
-                "max_input_tokens", defaults.COMMUNITY_REPORT_MAX_INPUT_LENGTH
-            ),
-        )
-        level_contexts.append(level_context)
-
     community_reports = await summarize_communities(
+        nodes,
         local_contexts,
-        level_contexts,
+        prep_community_report_context,
         callbacks,
         cache,
         summarization_strategy,
+        max_input_length=max_input_length,
         async_mode=async_mode,
         num_threads=num_threads,
     )
@@ -30,6 +30,7 @@ log = logging.getLogger(__name__)
 
 
 def prep_community_report_context(
+    report_df: pd.DataFrame | None,
     community_hierarchy_df: pd.DataFrame,
     local_context_df: pd.DataFrame,
     level: int,

@@ -42,8 +43,6 @@ def prep_community_report_context(
     - Check if local context fits within the limit, if yes use local context
     - If local context exceeds the limit, iteratively replace local context with sub-community reports, starting from the biggest sub-community
     """
-    report_df = pd.DataFrame()
-
     # Filter by community level
     level_context_df = local_context_df.loc[
         local_context_df.loc[:, schemas.COMMUNITY_LEVEL] == level

@@ -62,7 +61,7 @@ def prep_community_report_context(
     if invalid_context_df.empty:
         return valid_context_df
 
-    if report_df.empty:
+    if report_df is None or report_df.empty:
         invalid_context_df.loc[:, schemas.CONTEXT_STRING] = _sort_and_trim_context(
             invalid_context_df, max_tokens
         )
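
Note: prep_community_report_context no longer stubs report_df to an empty frame internally; the caller threads it in, and the guard treats "no reports yet" (None) the same as "no reports" (empty frame), falling back to trimming the local context. The equivalent predicate, written as a hypothetical helper for illustration only:

    import pandas as pd

    def _has_reports(report_df: pd.DataFrame | None) -> bool:
        # Mirrors the inverted guard above: substituting sub-community
        # reports only makes sense once some reports actually exist.
        return report_df is not None and not report_df.empty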
@@ -4,6 +4,7 @@
 """A module containing create_community_reports and load_strategy methods definition."""
 
 import logging
+from collections.abc import Callable
 
 import pandas as pd
 

@@ -12,6 +13,12 @@ from graphrag.cache.pipeline_cache import PipelineCache
 from graphrag.callbacks.noop_workflow_callbacks import NoopWorkflowCallbacks
 from graphrag.callbacks.workflow_callbacks import WorkflowCallbacks
 from graphrag.config.enums import AsyncType
+from graphrag.index.operations.summarize_communities.community_reports_extractor.utils import (
+    get_levels,
+)
+from graphrag.index.operations.summarize_communities.restore_community_hierarchy import (
+    restore_community_hierarchy,
+)
 from graphrag.index.operations.summarize_communities.typing import (
     CommunityReport,
     CommunityReportsStrategy,
@@ -24,11 +31,13 @@ log = logging.getLogger(__name__)
 
 
 async def summarize_communities(
+    nodes: pd.DataFrame,
     local_contexts,
-    level_contexts,
+    level_context_builder: Callable,
     callbacks: WorkflowCallbacks,
     cache: PipelineCache,
     strategy: dict,
+    max_input_length: int,
     async_mode: AsyncType = AsyncType.AsyncIO,
     num_threads: int = 4,
 ):

@@ -37,6 +46,20 @@ async def summarize_communities(
     tick = progress_ticker(callbacks.progress, len(local_contexts))
     runner = load_strategy(strategy["type"])
 
+    community_hierarchy = restore_community_hierarchy(nodes)
+    levels = get_levels(nodes)
+
+    level_contexts = []
+    for level in levels:
+        level_context = level_context_builder(
+            pd.DataFrame(reports),
+            community_hierarchy_df=community_hierarchy,
+            local_context_df=local_contexts,
+            level=level,
+            max_tokens=max_input_length,
+        )
+        level_contexts.append(level_context)
+
     for level_context in level_contexts:
 
         async def run_generate(record):
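
Note: level_context_builder is typed loosely as Callable, but the loop above fixes the shape the operation relies on: one positional reports frame (pd.DataFrame(reports)) plus keyword arguments, matching both builders' signatures in this diff. A stricter typing, purely hypothetical and not in the codebase, for readers wiring up their own builder:

    from typing import Protocol

    import pandas as pd

    class LevelContextBuilder(Protocol):
        def __call__(
            self,
            report_df: pd.DataFrame | None,  # reports generated so far
            *,
            community_hierarchy_df: pd.DataFrame,
            local_context_df: pd.DataFrame,
            level: int,
            max_tokens: int,
        ) -> pd.DataFrame: ...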
@@ -76,10 +76,10 @@ def prep_local_context(
 
 
 def prep_community_report_context(
-    local_context_df: pd.DataFrame,
+    report_df: pd.DataFrame | None,
     community_hierarchy_df: pd.DataFrame,
+    local_context_df: pd.DataFrame,
     level: int,
-    report_df: pd.DataFrame | None = None,
     max_tokens: int = 16000,
 ) -> pd.DataFrame:
     """

@@ -56,4 +56,5 @@ def _get_workflows_list(
         "create_final_communities",
         "create_final_text_units",
         "create_final_community_reports_text",
+        "generate_text_embeddings",
     ]
@@ -5,7 +5,7 @@
 
 from azure.identity import DefaultAzureCredential, get_bearer_token_provider
 
-from graphrag.config.enums import LLMType
+from graphrag.config.enums import AuthType, LLMType
 from graphrag.config.models.graph_rag_config import GraphRagConfig
 from graphrag.query.llm.oai.chat_openai import ChatOpenAI
 from graphrag.query.llm.oai.embedding import OpenAIEmbedding

@@ -31,7 +31,8 @@ def get_llm(config: GraphRagConfig) -> ChatOpenAI:
         api_key=default_llm_settings.api_key,
         azure_ad_token_provider=(
             get_bearer_token_provider(DefaultAzureCredential(), audience)
-            if is_azure_client and not default_llm_settings.api_key
+            if is_azure_client
+            and default_llm_settings.auth_type == AuthType.AzureManagedIdentity
             else None
         ),
         api_base=default_llm_settings.api_base,

@@ -65,7 +66,8 @@ def get_text_embedder(config: GraphRagConfig) -> OpenAIEmbedding:
         api_key=embeddings_llm_settings.api_key,
         azure_ad_token_provider=(
             get_bearer_token_provider(DefaultAzureCredential(), audience)
-            if is_azure_client and not embeddings_llm_settings.api_key
+            if is_azure_client
+            and embeddings_llm_settings.auth_type == AuthType.AzureManagedIdentity
             else None
         ),
         api_base=embeddings_llm_settings.api_base,
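
Note: query-side client construction now branches on the configured auth_type instead of inferring managed identity from a missing api_key. The selection pattern as a runnable sketch; is_azure_client and auth_type stand in for the factory locals, and the audience value matches the AZURE_AUDIENCE default shown earlier in this diff:

    from azure.identity import DefaultAzureCredential, get_bearer_token_provider

    from graphrag.config.enums import AuthType

    # Stand-ins for the factory locals (assumptions for illustration):
    is_azure_client = True
    auth_type = AuthType.AzureManagedIdentity
    audience = "https://cognitiveservices.azure.com/.default"

    token_provider = (
        get_bearer_token_provider(DefaultAzureCredential(), audience)
        if is_azure_client and auth_type == AuthType.AzureManagedIdentity
        else None  # api_key auth: the key itself is passed to the client
    )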
tests/fixtures/min-csv/settings.yml (vendored)
@@ -1,5 +1,6 @@
 models:
   default_chat_model:
+    azure_auth_type: api_key
     type: ${GRAPHRAG_LLM_TYPE}
     api_key: ${GRAPHRAG_API_KEY}
     api_base: ${GRAPHRAG_API_BASE}

@@ -13,6 +14,7 @@ models:
     parallelization_stagger: 0.3
     async_mode: threaded
   default_embedding_model:
+    azure_auth_type: api_key
     type: ${GRAPHRAG_EMBEDDING_TYPE}
     api_key: ${GRAPHRAG_API_KEY}
     api_base: ${GRAPHRAG_API_BASE}
tests/fixtures/text/settings.yml (vendored)
@@ -1,5 +1,6 @@
 models:
   default_chat_model:
+    azure_auth_type: api_key
     type: ${GRAPHRAG_LLM_TYPE}
     api_key: ${GRAPHRAG_API_KEY}
     api_base: ${GRAPHRAG_API_BASE}

@@ -13,6 +14,7 @@ models:
     parallelization_stagger: 0.3
     async_mode: threaded
   default_embedding_model:
+    azure_auth_type: api_key
     type: ${GRAPHRAG_EMBEDDING_TYPE}
     api_key: ${GRAPHRAG_API_KEY}
     api_base: ${GRAPHRAG_API_BASE}
@@ -10,7 +10,7 @@ from pydantic import ValidationError
 
 import graphrag.config.defaults as defs
 from graphrag.config.create_graphrag_config import create_graphrag_config
-from graphrag.config.enums import AzureAuthType, LLMType
+from graphrag.config.enums import AuthType, LLMType
 from graphrag.config.load_config import load_config
 from tests.unit.config.utils import (
     DEFAULT_EMBEDDING_MODEL_CONFIG,

@@ -46,7 +46,7 @@ def test_missing_azure_api_key() -> None:
     model_config_missing_api_key = {
         defs.DEFAULT_CHAT_MODEL_ID: {
             "type": LLMType.AzureOpenAIChat,
-            "azure_auth_type": AzureAuthType.APIKey,
+            "auth_type": AuthType.APIKey,
             "model": defs.LLM_MODEL,
             "api_base": "some_api_base",
             "api_version": "some_api_version",

@@ -59,17 +59,31 @@ def test_missing_azure_api_key() -> None:
         create_graphrag_config({"models": model_config_missing_api_key})
 
     # API Key not required for managed identity
-    model_config_missing_api_key[defs.DEFAULT_CHAT_MODEL_ID]["azure_auth_type"] = (
-        AzureAuthType.ManagedIdentity
+    model_config_missing_api_key[defs.DEFAULT_CHAT_MODEL_ID]["auth_type"] = (
+        AuthType.AzureManagedIdentity
     )
     create_graphrag_config({"models": model_config_missing_api_key})
 
 
+def test_conflicting_auth_type() -> None:
+    model_config_invalid_auth_type = {
+        defs.DEFAULT_CHAT_MODEL_ID: {
+            "auth_type": AuthType.AzureManagedIdentity,
+            "type": LLMType.OpenAIChat,
+            "model": defs.LLM_MODEL,
+        },
+        defs.DEFAULT_EMBEDDING_MODEL_ID: DEFAULT_EMBEDDING_MODEL_CONFIG,
+    }
+
+    with pytest.raises(ValidationError):
+        create_graphrag_config({"models": model_config_invalid_auth_type})
+
+
 def test_conflicting_azure_api_key() -> None:
     model_config_conflicting_api_key = {
         defs.DEFAULT_CHAT_MODEL_ID: {
             "type": LLMType.AzureOpenAIChat,
-            "azure_auth_type": AzureAuthType.ManagedIdentity,
+            "auth_type": AuthType.AzureManagedIdentity,
             "model": defs.LLM_MODEL,
             "api_base": "some_api_base",
             "api_version": "some_api_version",

@@ -85,7 +99,7 @@ def test_conflicting_azure_api_key() -> None:
 
     base_azure_model_config = {
         "type": LLMType.AzureOpenAIChat,
-        "azure_auth_type": AzureAuthType.ManagedIdentity,
+        "auth_type": AuthType.AzureManagedIdentity,
         "model": defs.LLM_MODEL,
         "api_base": "some_api_base",
         "api_version": "some_api_version",

@@ -241,7 +241,7 @@ def assert_language_model_configs(
     actual: LanguageModelConfig, expected: LanguageModelConfig
 ) -> None:
     assert actual.api_key == expected.api_key
-    assert actual.azure_auth_type == expected.azure_auth_type
+    assert actual.auth_type == expected.auth_type
     assert actual.type == expected.type
     assert actual.model == expected.model
     assert actual.encoding_model == expected.encoding_model