Refactor config (#1593)

* Refactor config

- Add new ModelConfig to represent LLM settings
    - Combines LLMParameters, ParallelizationParameters, encoding_model, and async_mode
- Add a top-level models config that is a list of available LLM ModelConfigs
- Remove LLMConfig inheritance and delete LLMConfig
    - Replace the inheritance with a model_id reference to the ModelConfig listed in the top level models config
- Remove all fallbacks and hydration logic from create_graphrag_config
    - This removes the automatic env variable overrides
- Support env variables within config files using Templating
    - This requires literal "$" characters to be escaped with an extra "$", so ".*\\.txt$" becomes ".*\\.txt$$"
- Update init content to initialize the new config file with the ModelConfig structure (sketched below)
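
A minimal sketch of the reshaped settings, written as the values dict that create_graphrag_config accepts (the same shape settings.yaml now takes). The model ids, the type/model values, and the placeholder API key are illustrative assumptions, and extra fields may be needed to satisfy validation; env-variable templating and the "$$" escaping apply when the settings file is loaded, so a dict passed directly carries already-resolved values.

```python
from graphrag.config.create_graphrag_config import create_graphrag_config

values = {
    # Top-level models config: LanguageModelConfig entries keyed by id.
    "models": {
        "default_chat_model": {
            "type": "openai_chat",  # illustrative; exact enum values may differ
            "model": "gpt-4o",
            "api_key": "sk-...",  # in settings.yaml this can be an env var reference
                                  # resolved by templating at load time
        },
        "default_embedding_model": {
            "type": "openai_embedding",
            "model": "text-embedding-3-small",
            "api_key": "sk-...",
        },
    },
    # Components now reference a model by id instead of inheriting LLM settings.
    "entity_extraction": {"model_id": "default_chat_model"},
    # In settings.yaml a literal "$" must be doubled ("$$"),
    # e.g. ".*\.txt$" is written ".*\.txt$$" in the file.
    "input": {"file_pattern": ".*\\.txt$"},
}

config = create_graphrag_config(values=values, root_dir=".")
```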

* Use dict of ModelConfig instead of list

* Add model validations and unit tests

* Fix ruff checks

* Add semversioner change

* Fix unit tests

* validate root_dir in pydantic model

* Rename ModelConfig to LanguageModelConfig

* Rename ModelConfigMissingError to LanguageModelConfigMissingError

* Add validation for unexpected API keys

* Allow skipping pydantic validation for testing/mocking purposes.

* Add default lm configs to verb tests

* smoke test

* remove config from flows to fix llm arg mapping

* Fix embedding llm arg mapping

* Remove timestamp from smoke test outputs

* Remove unused "subworkflows" smoke test properties

* Add models to smoke test configs

* Update smoke test output path

* Send logs to logs folder

* Fix output path

* Fix csv test file pattern

* Update placeholder

* Format

* Instantiate default model configs

* Fix unit tests for config defaults

* Fix migration notebook

* Remove create_pipeline_config

* Remove several unused config models

* Remove indexing embedding and input configs

* Move embeddings function to config

* Remove skip_workflows

* Remove skip embeddings in favor of explicit naming

* fix unit test spelling mistake

* self.models[model_id] is already a language model. Remove redundant casting.

* update validation errors to instruct users to rerun graphrag init

* instantiate LanguageModelConfigs with validation

* skip validation in unit tests

* update verb tests to use default model settings instead of skipping validation

* test using llm settings

* cleanup verb tests

* remove unsafe default model config

* remove the ability to skip pydantic validation

* remove None union types when default values are set

* move vector_store from embeddings to top level of config and delete resolve_paths
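
Search entrypoints now read vector store settings from the top level of the config rather than from config.embeddings; a minimal sketch of the new access pattern, assuming a project with a settings file in the working directory:

```python
from pathlib import Path

from graphrag.config.load_config import load_config

config = load_config(Path("."))

# vector_store now lives directly on GraphRagConfig instead of under embeddings.
vector_store_args = config.vector_store.model_dump()
print(vector_store_args)
```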

* update vector store settings

* fix vector store and smoke tests

* fix serializing vector_store settings

* fix vector_store usage

* fix vector_store type

* support cli overrides for loading graphrag config
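
The same dotted-path override mechanism the CLI uses for --output is available to programmatic callers of load_config; a minimal sketch (paths are illustrative):

```python
from pathlib import Path

from graphrag.config.load_config import load_config

# Dotted-path keys override the corresponding values from settings.yaml,
# mirroring what the CLI does when --output is passed.
cli_overrides = {
    "output.base_dir": "./my_output",
    "reporting.base_dir": "./my_output",
}
config = load_config(Path("."), None, cli_overrides)
```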

* rename storage to output

* Add --force flag to init
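
The --force flag is passed straight through to the project initializer; a minimal programmatic sketch of the same behavior:

```python
from pathlib import Path

from graphrag.cli.initialize import initialize_project_at

# With force=True, settings.yaml, .env, and the prompt files are rewritten
# even if the project directory is already initialized.
initialize_project_at(path=Path("."), force=True)
```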

* Remove run_id and resume, fix Drift config assignment

* Ruff

---------

Co-authored-by: Nathan Evans <github@talkswithnumbers.com>
Co-authored-by: Alonso Guevara <alonsog@microsoft.com>
Derek Worthen 2025-01-21 15:52:06 -08:00 committed by GitHub
parent 47adfe16f0
commit c644338bae
104 changed files with 2251 additions and 3608 deletions

View File

@ -0,0 +1,4 @@
{
"type": "minor",
"description": "Remove config inheritance, hydration, and automatic env var overlays."
}

View File

@ -2,7 +2,7 @@
"cells": [
{
"cell_type": "code",
"execution_count": 1,
"execution_count": 66,
"metadata": {},
"outputs": [],
"source": [
@ -25,7 +25,7 @@
},
{
"cell_type": "code",
"execution_count": 2,
"execution_count": 67,
"metadata": {},
"outputs": [],
"source": [
@ -37,7 +37,7 @@
},
{
"cell_type": "code",
"execution_count": 3,
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
@ -45,19 +45,20 @@
"\n",
"from graphrag.config.load_config import load_config\n",
"from graphrag.config.resolve_path import resolve_paths\n",
"from graphrag.index.create_pipeline_config import create_pipeline_config\n",
"from graphrag.storage.factory import create_storage\n",
"from graphrag.storage.factory import StorageFactory\n",
"\n",
"# This first block does some config loading, path resolution, and translation that is normally done by the CLI/API when running a full workflow\n",
"config = load_config(Path(PROJECT_DIRECTORY))\n",
"resolve_paths(config)\n",
"pipeline_config = create_pipeline_config(config)\n",
"storage = create_storage(pipeline_config.storage)"
"storage_config = config.storage.model_dump() # type: ignore\n",
"storage = StorageFactory().create_storage(\n",
" storage_type=storage_config[\"type\"], kwargs=storage_config\n",
")"
]
},
{
"cell_type": "code",
"execution_count": 4,
"execution_count": 69,
"metadata": {},
"outputs": [],
"source": [
@ -68,7 +69,7 @@
},
{
"cell_type": "code",
"execution_count": 63,
"execution_count": 70,
"metadata": {},
"outputs": [],
"source": [
@ -97,7 +98,7 @@
},
{
"cell_type": "code",
"execution_count": 64,
"execution_count": 71,
"metadata": {},
"outputs": [],
"source": [
@ -108,22 +109,16 @@
"# First we'll go through any parquet files that had model changes and update them\n",
"# The new data model may have removed excess columns as well, but we will only make the minimal changes required for compatibility\n",
"\n",
"final_documents = await load_table_from_storage(\n",
" \"create_final_documents.parquet\", storage\n",
")\n",
"final_text_units = await load_table_from_storage(\n",
" \"create_final_text_units.parquet\", storage\n",
")\n",
"final_entities = await load_table_from_storage(\"create_final_entities.parquet\", storage)\n",
"final_nodes = await load_table_from_storage(\"create_final_nodes.parquet\", storage)\n",
"final_documents = await load_table_from_storage(\"create_final_documents\", storage)\n",
"final_text_units = await load_table_from_storage(\"create_final_text_units\", storage)\n",
"final_entities = await load_table_from_storage(\"create_final_entities\", storage)\n",
"final_nodes = await load_table_from_storage(\"create_final_nodes\", storage)\n",
"final_relationships = await load_table_from_storage(\n",
" \"create_final_relationships.parquet\", storage\n",
")\n",
"final_communities = await load_table_from_storage(\n",
" \"create_final_communities.parquet\", storage\n",
" \"create_final_relationships\", storage\n",
")\n",
"final_communities = await load_table_from_storage(\"create_final_communities\", storage)\n",
"final_community_reports = await load_table_from_storage(\n",
" \"create_final_community_reports.parquet\", storage\n",
" \"create_final_community_reports\", storage\n",
")\n",
"\n",
"\n",
@ -183,44 +178,41 @@
" parent_df, on=\"community\", how=\"left\"\n",
" )\n",
"\n",
"await write_table_to_storage(final_documents, \"create_final_documents.parquet\", storage)\n",
"await write_table_to_storage(final_documents, \"create_final_documents\", storage)\n",
"await write_table_to_storage(final_text_units, \"create_final_text_units\", storage)\n",
"await write_table_to_storage(final_entities, \"create_final_entities\", storage)\n",
"await write_table_to_storage(final_nodes, \"create_final_nodes\", storage)\n",
"await write_table_to_storage(final_relationships, \"create_final_relationships\", storage)\n",
"await write_table_to_storage(final_communities, \"create_final_communities\", storage)\n",
"await write_table_to_storage(\n",
" final_text_units, \"create_final_text_units.parquet\", storage\n",
")\n",
"await write_table_to_storage(final_entities, \"create_final_entities.parquet\", storage)\n",
"await write_table_to_storage(final_nodes, \"create_final_nodes.parquet\", storage)\n",
"await write_table_to_storage(\n",
" final_relationships, \"create_final_relationships.parquet\", storage\n",
")\n",
"await write_table_to_storage(\n",
" final_communities, \"create_final_communities.parquet\", storage\n",
")\n",
"await write_table_to_storage(\n",
" final_community_reports, \"create_final_community_reports.parquet\", storage\n",
" final_community_reports, \"create_final_community_reports\", storage\n",
")"
]
},
{
"cell_type": "code",
"execution_count": 7,
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"from graphrag.cache.factory import create_cache\n",
"from graphrag.cache.factory import CacheFactory\n",
"from graphrag.callbacks.noop_workflow_callbacks import NoopWorkflowCallbacks\n",
"from graphrag.index.config.embeddings import get_embedded_fields, get_embedding_settings\n",
"from graphrag.index.flows.generate_text_embeddings import generate_text_embeddings\n",
"\n",
"# We only need to re-run the embeddings workflow, to ensure that embeddings for all required search fields are in place\n",
"# We'll construct the context and run this function flow directly to avoid everything else\n",
"\n",
"workflow = next(\n",
" (x for x in pipeline_config.workflows if x.name == \"generate_text_embeddings\"), None\n",
")\n",
"config = workflow.config\n",
"text_embed = config.get(\"text_embed\", {})\n",
"embedded_fields = config.get(\"embedded_fields\", {})\n",
"\n",
"embedded_fields = get_embedded_fields(config)\n",
"text_embed = get_embedding_settings(config)\n",
"callbacks = NoopWorkflowCallbacks()\n",
"cache = create_cache(pipeline_config.cache, PROJECT_DIRECTORY)\n",
"cache_config = config.cache.model_dump() # type: ignore\n",
"cache = CacheFactory().create_cache(\n",
" cache_type=cache_config[\"type\"], # type: ignore\n",
" root_dir=PROJECT_DIRECTORY,\n",
" kwargs=cache_config,\n",
")\n",
"\n",
"await generate_text_embeddings(\n",
" final_documents=None,\n",

View File

@ -11,7 +11,7 @@ Backwards compatibility is not guaranteed at this time.
import logging
from graphrag.cache.noop_pipeline_cache import NoopPipelineCache
from graphrag.callbacks.factory import create_pipeline_reporter
from graphrag.callbacks.reporting import create_pipeline_reporter
from graphrag.callbacks.workflow_callbacks import WorkflowCallbacks
from graphrag.config.enums import CacheType
from graphrag.config.models.graph_rag_config import GraphRagConfig
@ -24,8 +24,6 @@ log = logging.getLogger(__name__)
async def build_index(
config: GraphRagConfig,
run_id: str = "",
is_resume_run: bool = False,
memory_profile: bool = False,
callbacks: list[WorkflowCallbacks] | None = None,
progress_logger: ProgressLogger | None = None,
@ -36,10 +34,6 @@ async def build_index(
----------
config : GraphRagConfig
The configuration.
run_id : str
The run id. Creates a output directory with this name.
is_resume_run : bool default=False
Whether to resume a previous index run.
memory_profile : bool
Whether to enable memory profiling.
callbacks : list[WorkflowCallbacks] | None default=None
@ -52,11 +46,7 @@ async def build_index(
list[PipelineRunResult]
The list of pipeline run results
"""
is_update_run = bool(config.update_index_storage)
if is_resume_run and is_update_run:
msg = "Cannot resume and update a run at the same time."
raise ValueError(msg)
is_update_run = bool(config.update_index_output)
pipeline_cache = (
NoopPipelineCache() if config.cache.type == CacheType.none is None else None
@ -78,7 +68,6 @@ async def build_index(
cache=pipeline_cache,
callbacks=callbacks,
logger=progress_logger,
run_id=run_id,
is_update_run=is_update_run,
):
outputs.append(output)
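
With run_id and is_resume_run removed, a basic indexing run reduces to roughly the following sketch (logger, callbacks, and memory profiling omitted; assumes a configured project in the working directory):

```python
import asyncio
from pathlib import Path

import graphrag.api as api
from graphrag.config.load_config import load_config

config = load_config(Path("."))
results = asyncio.run(api.build_index(config=config))  # list of pipeline run results
```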

View File

@ -17,7 +17,7 @@ from graphrag.callbacks.noop_workflow_callbacks import NoopWorkflowCallbacks
from graphrag.config.models.graph_rag_config import GraphRagConfig
from graphrag.index.llm.load_llm import load_llm
from graphrag.logger.print_progress import PrintProgressLogger
from graphrag.prompt_tune.defaults import MAX_TOKEN_COUNT
from graphrag.prompt_tune.defaults import MAX_TOKEN_COUNT, PROMPT_TUNING_MODEL_ID
from graphrag.prompt_tune.generator.community_report_rating import (
generate_community_report_rating,
)
@ -95,9 +95,11 @@ async def generate_indexing_prompts(
)
# Create LLM from config
# TODO: Expose way to specify Prompt Tuning model ID through config
default_llm_settings = config.get_language_model_config(PROMPT_TUNING_MODEL_ID)
llm = load_llm(
"prompt_tuning",
config.llm,
default_llm_settings,
cache=None,
callbacks=NoopWorkflowCallbacks(),
)
@ -120,6 +122,9 @@ async def generate_indexing_prompts(
)
entity_types = None
entity_extraction_llm_settings = config.get_language_model_config(
config.entity_extraction.model_id
)
if discover_entity_types:
logger.info("Generating entity types...")
entity_types = await generate_entity_types(
@ -127,7 +132,7 @@ async def generate_indexing_prompts(
domain=domain,
persona=persona,
docs=doc_list,
json_mode=config.llm.model_supports_json or False,
json_mode=entity_extraction_llm_settings.model_supports_json or False,
)
logger.info("Generating entity relationship examples...")
@ -147,7 +152,7 @@ async def generate_indexing_prompts(
examples=examples,
language=language,
json_mode=False, # config.llm.model_supports_json should be used, but these prompts are used in non-json mode by the index engine
encoding_model=config.encoding_model,
encoding_model=entity_extraction_llm_settings.encoding_model,
max_token_count=max_tokens,
min_examples_required=min_examples_required,
)
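
Because there is no longer a single config.llm, each component resolves its LLM settings from the top-level models dict by id; a minimal sketch of that lookup, assuming a loaded config:

```python
from pathlib import Path

from graphrag.config.load_config import load_config

config = load_config(Path("."))

# entity_extraction carries a model_id that points into the top-level models dict.
llm_settings = config.get_language_model_config(config.entity_extraction.model_id)
print(llm_settings.model_supports_json, llm_settings.encoding_model)
```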

View File

@ -24,12 +24,13 @@ from typing import TYPE_CHECKING, Any
import pandas as pd
from pydantic import validate_call
from graphrag.config.models.graph_rag_config import GraphRagConfig
from graphrag.index.config.embeddings import (
from graphrag.config.embeddings import (
community_full_content_embedding,
create_collection_name,
entity_description_embedding,
text_unit_text_embedding,
)
from graphrag.config.models.graph_rag_config import GraphRagConfig
from graphrag.logger.print_progress import PrintProgressLogger
from graphrag.query.factory import (
get_basic_search_engine,
@ -47,7 +48,6 @@ from graphrag.query.indexer_adapters import (
read_indexer_text_units,
)
from graphrag.utils.cli import redact
from graphrag.utils.embeddings import create_collection_name
from graphrag.vector_stores.base import BaseVectorStore
from graphrag.vector_stores.factory import VectorStoreFactory
@ -244,7 +244,7 @@ async def local_search(
------
TODO: Document any exceptions to expect.
"""
vector_store_args = config.embeddings.vector_store
vector_store_args = config.vector_store.model_dump()
logger.info(f"Vector Store Args: {redact(vector_store_args)}") # type: ignore # noqa
description_embedding_store = _get_embedding_store(
@ -310,7 +310,7 @@ async def local_search_streaming(
------
TODO: Document any exceptions to expect.
"""
vector_store_args = config.embeddings.vector_store
vector_store_args = config.vector_store.model_dump()
logger.info(f"Vector Store Args: {redact(vector_store_args)}") # type: ignore # noqa
description_embedding_store = _get_embedding_store(
@ -381,7 +381,7 @@ async def drift_search_streaming(
------
TODO: Document any exceptions to expect.
"""
vector_store_args = config.embeddings.vector_store
vector_store_args = config.vector_store.model_dump()
logger.info(f"Vector Store Args: {redact(vector_store_args)}") # type: ignore # noqa
description_embedding_store = _get_embedding_store(
@ -465,7 +465,7 @@ async def drift_search(
------
TODO: Document any exceptions to expect.
"""
vector_store_args = config.embeddings.vector_store
vector_store_args = config.vector_store.model_dump()
logger.info(f"Vector Store Args: {redact(vector_store_args)}") # type: ignore # noqa
description_embedding_store = _get_embedding_store(
@ -531,7 +531,7 @@ async def basic_search(
------
TODO: Document any exceptions to expect.
"""
vector_store_args = config.embeddings.vector_store
vector_store_args = config.vector_store.model_dump()
logger.info(f"Vector Store Args: {redact(vector_store_args)}") # type: ignore # noqa
description_embedding_store = _get_embedding_store(
@ -576,7 +576,7 @@ async def basic_search_streaming(
------
TODO: Document any exceptions to expect.
"""
vector_store_args = config.embeddings.vector_store
vector_store_args = config.vector_store.model_dump()
logger.info(f"Vector Store Args: {redact(vector_store_args)}") # type: ignore # noqa
description_embedding_store = _get_embedding_store(

View File

@ -1,45 +0,0 @@
# Copyright (c) 2024 Microsoft Corporation.
# Licensed under the MIT License
"""Create a pipeline logger."""
from pathlib import Path
from typing import cast
from graphrag.callbacks.blob_workflow_callbacks import BlobWorkflowCallbacks
from graphrag.callbacks.console_workflow_callbacks import ConsoleWorkflowCallbacks
from graphrag.callbacks.file_workflow_callbacks import FileWorkflowCallbacks
from graphrag.callbacks.workflow_callbacks import WorkflowCallbacks
from graphrag.config.enums import ReportingType
from graphrag.index.config.reporting import (
PipelineBlobReportingConfig,
PipelineFileReportingConfig,
PipelineReportingConfig,
)
def create_pipeline_reporter(
config: PipelineReportingConfig | None, root_dir: str | None
) -> WorkflowCallbacks:
"""Create a logger for the given pipeline config."""
config = config or PipelineFileReportingConfig(base_dir="logs")
match config.type:
case ReportingType.file:
config = cast("PipelineFileReportingConfig", config)
return FileWorkflowCallbacks(
str(Path(root_dir or "") / (config.base_dir or ""))
)
case ReportingType.console:
return ConsoleWorkflowCallbacks()
case ReportingType.blob:
config = cast("PipelineBlobReportingConfig", config)
return BlobWorkflowCallbacks(
config.connection_string,
config.container_name,
base_dir=config.base_dir,
storage_account_blob_url=config.storage_account_blob_url,
)
case _:
msg = f"Unknown reporting type: {config.type}"
raise ValueError(msg)

View File

@ -5,12 +5,19 @@
from __future__ import annotations
from typing import Generic, Literal, TypeVar
from pathlib import Path
from typing import TYPE_CHECKING, Generic, Literal, TypeVar, cast
from pydantic import BaseModel, Field
from graphrag.callbacks.blob_workflow_callbacks import BlobWorkflowCallbacks
from graphrag.callbacks.console_workflow_callbacks import ConsoleWorkflowCallbacks
from graphrag.callbacks.file_workflow_callbacks import FileWorkflowCallbacks
from graphrag.config.enums import ReportingType
if TYPE_CHECKING:
from graphrag.callbacks.workflow_callbacks import WorkflowCallbacks
T = TypeVar("T")
@ -74,3 +81,30 @@ PipelineReportingConfigTypes = (
| PipelineConsoleReportingConfig
| PipelineBlobReportingConfig
)
def create_pipeline_reporter(
config: PipelineReportingConfig | None, root_dir: str | None
) -> WorkflowCallbacks:
"""Create a logger for the given pipeline config."""
config = config or PipelineFileReportingConfig(base_dir="logs")
match config.type:
case ReportingType.file:
config = cast("PipelineFileReportingConfig", config)
return FileWorkflowCallbacks(
str(Path(root_dir or "") / (config.base_dir or ""))
)
case ReportingType.console:
return ConsoleWorkflowCallbacks()
case ReportingType.blob:
config = cast("PipelineBlobReportingConfig", config)
return BlobWorkflowCallbacks(
config.connection_string,
config.container_name,
base_dir=config.base_dir,
storage_account_blob_url=config.storage_account_blob_url,
)
case _:
msg = f"Unknown reporting type: {config.type}"
raise ValueError(msg)

View File

@ -6,7 +6,6 @@
import asyncio
import logging
import sys
import time
import warnings
from pathlib import Path
@ -14,7 +13,6 @@ import graphrag.api as api
from graphrag.config.enums import CacheType
from graphrag.config.load_config import load_config
from graphrag.config.logging import enable_logging_with_config
from graphrag.config.resolve_path import resolve_paths
from graphrag.index.validate_config import validate_config_names
from graphrag.logger.base import ProgressLogger
from graphrag.logger.factory import LoggerFactory, LoggerType
@ -66,7 +64,6 @@ def _register_signal_handlers(logger: ProgressLogger):
def index_cli(
root_dir: Path,
verbose: bool,
resume: str | None,
memprofile: bool,
cache: bool,
logger: LoggerType,
@ -76,18 +73,20 @@ def index_cli(
output_dir: Path | None,
):
"""Run the pipeline with the given config."""
config = load_config(root_dir, config_filepath)
cli_overrides = {}
if output_dir:
cli_overrides["output.base_dir"] = str(output_dir)
cli_overrides["reporting.base_dir"] = str(output_dir)
config = load_config(root_dir, config_filepath, cli_overrides)
_run_index(
config=config,
verbose=verbose,
resume=resume,
memprofile=memprofile,
cache=cache,
logger=logger,
dry_run=dry_run,
skip_validation=skip_validation,
output_dir=output_dir,
)
@ -102,51 +101,44 @@ def update_cli(
output_dir: Path | None,
):
"""Run the pipeline with the given config."""
config = load_config(root_dir, config_filepath)
cli_overrides = {}
if output_dir:
cli_overrides["output.base_dir"] = str(output_dir)
cli_overrides["reporting.base_dir"] = str(output_dir)
config = load_config(root_dir, config_filepath, cli_overrides)
# Check if update storage exist, if not configure it with default values
if not config.update_index_storage:
from graphrag.config.defaults import STORAGE_TYPE, UPDATE_STORAGE_BASE_DIR
from graphrag.config.models.storage_config import StorageConfig
# Check if update output exist, if not configure it with default values
if not config.update_index_output:
from graphrag.config.defaults import OUTPUT_TYPE, UPDATE_OUTPUT_BASE_DIR
from graphrag.config.models.output_config import OutputConfig
config.update_index_storage = StorageConfig(
type=STORAGE_TYPE,
base_dir=UPDATE_STORAGE_BASE_DIR,
config.update_index_output = OutputConfig(
type=OUTPUT_TYPE,
base_dir=UPDATE_OUTPUT_BASE_DIR,
)
_run_index(
config=config,
verbose=verbose,
resume=False,
memprofile=memprofile,
cache=cache,
logger=logger,
dry_run=False,
skip_validation=skip_validation,
output_dir=output_dir,
)
def _run_index(
config,
verbose,
resume,
memprofile,
cache,
logger,
dry_run,
skip_validation,
output_dir,
):
progress_logger = LoggerFactory().create_logger(logger)
info, error, success = _logger(progress_logger)
run_id = resume or time.strftime("%Y%m%d-%H%M%S")
config.storage.base_dir = str(output_dir) if output_dir else config.storage.base_dir
config.reporting.base_dir = (
str(output_dir) if output_dir else config.reporting.base_dir
)
resolve_paths(config, run_id)
if not cache:
config.cache.type = CacheType.none
@ -163,7 +155,7 @@ def _run_index(
if skip_validation:
validate_config_names(progress_logger, config)
info(f"Starting pipeline run for: {run_id}, {dry_run=}", verbose)
info(f"Starting pipeline run. {dry_run=}", verbose)
info(
f"Using default configuration: {redact(config.model_dump())}",
verbose,
@ -178,8 +170,6 @@ def _run_index(
outputs = asyncio.run(
api.build_index(
config=config,
run_id=run_id,
is_resume_run=bool(resume),
memory_profile=memprofile,
progress_logger=progress_logger,
)

View File

@ -29,8 +29,22 @@ from graphrag.prompts.query.local_search_system_prompt import LOCAL_SEARCH_SYSTE
from graphrag.prompts.query.question_gen_system_prompt import QUESTION_SYSTEM_PROMPT
def initialize_project_at(path: Path) -> None:
"""Initialize the project at the given path."""
def initialize_project_at(path: Path, force: bool) -> None:
"""
Initialize the project at the given path.
Parameters
----------
path : Path
The path at which to initialize the project.
force : bool
Whether to force initialization even if the project already exists.
Raises
------
ValueError
If the project already exists and force is False.
"""
progress_logger = LoggerFactory().create_logger(LoggerType.RICH)
progress_logger.info(f"Initializing project at {path}") # noqa: G004
root = Path(path)
@ -38,7 +52,7 @@ def initialize_project_at(path: Path) -> None:
root.mkdir(parents=True, exist_ok=True)
settings_yaml = root / "settings.yaml"
if settings_yaml.exists():
if settings_yaml.exists() and not force:
msg = f"Project already initialized at {root}"
raise ValueError(msg)
@ -46,7 +60,7 @@ def initialize_project_at(path: Path) -> None:
file.write(INIT_YAML.encode(encoding="utf-8", errors="strict"))
dotenv = root / ".env"
if not dotenv.exists():
if not dotenv.exists() or force:
with dotenv.open("wb") as file:
file.write(INIT_DOTENV.encode(encoding="utf-8", errors="strict"))
@ -71,6 +85,6 @@ def initialize_project_at(path: Path) -> None:
for name, content in prompts.items():
prompt_file = prompts_dir / f"{name}.txt"
if not prompt_file.exists():
if not prompt_file.exists() or force:
with prompt_file.open("wb") as file:
file.write(content.encode(encoding="utf-8", errors="strict"))

View File

@ -109,11 +109,15 @@ def _initialize_cli(
),
),
],
force: Annotated[
bool,
typer.Option(help="Force initialization even if the project already exists."),
] = False,
):
"""Generate a default configuration file."""
from graphrag.cli.initialize import initialize_project_at
initialize_project_at(path=root)
initialize_project_at(path=root, force=force)
@app.command("index")
@ -143,9 +147,6 @@ def _index_cli(
memprofile: Annotated[
bool, typer.Option(help="Run the indexing pipeline with memory profiling")
] = False,
resume: Annotated[
str | None, typer.Option(help="Resume a given indexing run")
] = None,
logger: Annotated[
LoggerType, typer.Option(help="The progress logger to use.")
] = LoggerType.RICH,
@ -165,7 +166,7 @@ def _index_cli(
output: Annotated[
Path | None,
typer.Option(
help="Indexing pipeline output directory. Overrides storage.base_dir in the configuration file.",
help="Indexing pipeline output directory. Overrides output.base_dir in the configuration file.",
dir_okay=True,
writable=True,
resolve_path=True,
@ -178,7 +179,6 @@ def _index_cli(
index_cli(
root_dir=root,
verbose=verbose,
resume=resume,
memprofile=memprofile,
cache=cache,
logger=LoggerType(logger),
@ -226,7 +226,7 @@ def _update_cli(
output: Annotated[
Path | None,
typer.Option(
help="Indexing pipeline output directory. Overrides storage.base_dir in the configuration file.",
help="Indexing pipeline output directory. Overrides output.base_dir in the configuration file.",
dir_okay=True,
writable=True,
resolve_path=True,
@ -236,7 +236,7 @@ def _update_cli(
"""
Update an existing knowledge graph index.
Applies a default storage configuration (if not provided by config), saving the new index to the local file system in the `update_output` folder.
Applies a default output configuration (if not provided by config), saving the new index to the local file system in the `update_output` folder.
"""
from graphrag.cli.index import update_cli

View File

@ -12,8 +12,6 @@ import pandas as pd
import graphrag.api as api
from graphrag.config.load_config import load_config
from graphrag.config.models.graph_rag_config import GraphRagConfig
from graphrag.config.resolve_path import resolve_paths
from graphrag.index.create_pipeline_config import create_pipeline_config
from graphrag.logger.print_progress import PrintProgressLogger
from graphrag.storage.factory import StorageFactory
from graphrag.utils.storage import load_table_from_storage, storage_has_table
@ -36,9 +34,10 @@ def run_global_search(
Loads index files required for global search and calls the Query API.
"""
root = root_dir.resolve()
config = load_config(root, config_filepath)
config.storage.base_dir = str(data_dir) if data_dir else config.storage.base_dir
resolve_paths(config)
cli_overrides = {}
if data_dir:
cli_overrides["output.base_dir"] = str(data_dir)
config = load_config(root, config_filepath, cli_overrides)
dataframe_dict = _resolve_output_files(
config=config,
@ -120,9 +119,10 @@ def run_local_search(
Loads index files required for local search and calls the Query API.
"""
root = root_dir.resolve()
config = load_config(root, config_filepath)
config.storage.base_dir = str(data_dir) if data_dir else config.storage.base_dir
resolve_paths(config)
cli_overrides = {}
if data_dir:
cli_overrides["output.base_dir"] = str(data_dir)
config = load_config(root, config_filepath, cli_overrides)
dataframe_dict = _resolve_output_files(
config=config,
@ -211,9 +211,10 @@ def run_drift_search(
Loads index files required for local search and calls the Query API.
"""
root = root_dir.resolve()
config = load_config(root, config_filepath)
config.storage.base_dir = str(data_dir) if data_dir else config.storage.base_dir
resolve_paths(config)
cli_overrides = {}
if data_dir:
cli_overrides["output.base_dir"] = str(data_dir)
config = load_config(root, config_filepath, cli_overrides)
dataframe_dict = _resolve_output_files(
config=config,
@ -296,9 +297,10 @@ def run_basic_search(
Loads index files required for basic search and calls the Query API.
"""
root = root_dir.resolve()
config = load_config(root, config_filepath)
config.storage.base_dir = str(data_dir) if data_dir else config.storage.base_dir
resolve_paths(config)
cli_overrides = {}
if data_dir:
cli_overrides["output.base_dir"] = str(data_dir)
config = load_config(root, config_filepath, cli_overrides)
dataframe_dict = _resolve_output_files(
config=config,
@ -352,10 +354,9 @@ def _resolve_output_files(
) -> dict[str, pd.DataFrame]:
"""Read indexing output files to a dataframe dict."""
dataframe_dict = {}
pipeline_config = create_pipeline_config(config)
storage_config = pipeline_config.storage.model_dump() # type: ignore
output_config = config.output.model_dump() # type: ignore
storage_obj = StorageFactory().create_storage(
storage_type=storage_config["type"], kwargs=storage_config
storage_type=output_config["type"], kwargs=output_config
)
for name in output_list:
df_value = asyncio.run(load_table_from_storage(name=name, storage=storage_obj))

View File

@ -1,179 +0,0 @@
# Copyright (c) 2024 Microsoft Corporation.
# Licensed under the MIT License
"""Load a GraphRagConfiguration from a file."""
import json
from abc import ABC, abstractmethod
from pathlib import Path
import yaml
from graphrag.config.create_graphrag_config import create_graphrag_config
from graphrag.config.models.graph_rag_config import GraphRagConfig
_default_config_files = ["settings.yaml", "settings.yml", "settings.json"]
def search_for_config_in_root_dir(root: str | Path) -> Path | None:
"""Resolve the config path from the given root directory.
Parameters
----------
root : str | Path
The path to the root directory containing the config file.
Searches for a default config file (settings.{yaml,yml,json}).
Returns
-------
Path | None
returns a Path if there is a config in the root directory
Otherwise returns None.
"""
root = Path(root)
if not root.is_dir():
msg = f"Invalid config path: {root} is not a directory"
raise FileNotFoundError(msg)
for file in _default_config_files:
if (root / file).is_file():
return root / file
return None
class ConfigFileLoader(ABC):
"""Base class for loading a configuration from a file."""
@abstractmethod
def load_config(self, config_path: str | Path) -> GraphRagConfig:
"""Load configuration from a file."""
raise NotImplementedError
class ConfigYamlLoader(ConfigFileLoader):
"""Load a configuration from a yaml file."""
def load_config(self, config_path: str | Path) -> GraphRagConfig:
"""Load a configuration from a yaml file.
Parameters
----------
config_path : str | Path
The path to the yaml file to load.
Returns
-------
GraphRagConfig
The loaded configuration.
Raises
------
ValueError
If the file extension is not .yaml or .yml.
FileNotFoundError
If the config file is not found.
"""
config_path = Path(config_path)
if config_path.suffix not in [".yaml", ".yml"]:
msg = f"Invalid file extension for loading yaml config from: {config_path!s}. Expected .yaml or .yml"
raise ValueError(msg)
root_dir = str(config_path.parent)
if not config_path.is_file():
msg = f"Config file not found: {config_path}"
raise FileNotFoundError(msg)
with config_path.open("rb") as file:
data = yaml.safe_load(file.read().decode(encoding="utf-8", errors="strict"))
return create_graphrag_config(data, root_dir)
class ConfigJsonLoader(ConfigFileLoader):
"""Load a configuration from a json file."""
def load_config(self, config_path: str | Path) -> GraphRagConfig:
"""Load a configuration from a json file.
Parameters
----------
config_path : str | Path
The path to the json file to load.
Returns
-------
GraphRagConfig
The loaded configuration.
Raises
------
ValueError
If the file extension is not .json.
FileNotFoundError
If the config file is not found.
"""
config_path = Path(config_path)
root_dir = str(config_path.parent)
if config_path.suffix != ".json":
msg = f"Invalid file extension for loading json config from: {config_path!s}. Expected .json"
raise ValueError(msg)
if not config_path.is_file():
msg = f"Config file not found: {config_path}"
raise FileNotFoundError(msg)
with config_path.open("rb") as file:
data = json.loads(file.read().decode(encoding="utf-8", errors="strict"))
return create_graphrag_config(data, root_dir)
def get_config_file_loader(config_path: str | Path) -> ConfigFileLoader:
"""Config File Loader Factory.
Parameters
----------
config_path : str | Path
The path to the config file.
Returns
-------
ConfigFileLoader
The config file loader for the provided config file.
Raises
------
ValueError
If the config file extension is not supported.
"""
config_path = Path(config_path)
ext = config_path.suffix
match ext:
case ".yaml" | ".yml":
return ConfigYamlLoader()
case ".json":
return ConfigJsonLoader()
case _:
msg = f"Unsupported config file extension: {ext}"
raise ValueError(msg)
def load_config_from_file(config_path: str | Path) -> GraphRagConfig:
"""Load a configuration from a file.
Parameters
----------
config_path : str | Path
The path to the configuration file.
Supports .yaml, .yml, and .json config files.
Returns
-------
GraphRagConfig
The loaded configuration.
Raises
------
ValueError
If the file extension is not supported.
FileNotFoundError
If the config file is not found.
"""
loader = get_config_file_loader(config_path)
return loader.load_config(config_path)

View File

@ -3,777 +3,41 @@
"""Parameterization settings for the default configuration, loaded from environment variables."""
import os
from enum import Enum
from pathlib import Path
from typing import Any, cast
from typing import Any
from environs import Env
import graphrag.config.defaults as defs
from graphrag.config.enums import (
AsyncType,
CacheType,
InputFileType,
InputType,
LLMType,
ReportingType,
StorageType,
TextEmbeddingTarget,
)
from graphrag.config.environment_reader import EnvironmentReader
from graphrag.config.errors import (
ApiKeyMissingError,
AzureApiBaseMissingError,
AzureDeploymentNameMissingError,
)
from graphrag.config.models.basic_search_config import BasicSearchConfig
from graphrag.config.models.cache_config import CacheConfig
from graphrag.config.models.chunking_config import ChunkingConfig, ChunkStrategyType
from graphrag.config.models.claim_extraction_config import ClaimExtractionConfig
from graphrag.config.models.cluster_graph_config import ClusterGraphConfig
from graphrag.config.models.community_reports_config import CommunityReportsConfig
from graphrag.config.models.drift_search_config import DRIFTSearchConfig
from graphrag.config.models.embed_graph_config import EmbedGraphConfig
from graphrag.config.models.entity_extraction_config import EntityExtractionConfig
from graphrag.config.models.global_search_config import GlobalSearchConfig
from graphrag.config.models.graph_rag_config import GraphRagConfig
from graphrag.config.models.input_config import InputConfig
from graphrag.config.models.llm_parameters import LLMParameters
from graphrag.config.models.local_search_config import LocalSearchConfig
from graphrag.config.models.parallelization_parameters import ParallelizationParameters
from graphrag.config.models.reporting_config import ReportingConfig
from graphrag.config.models.snapshots_config import SnapshotsConfig
from graphrag.config.models.storage_config import StorageConfig
from graphrag.config.models.summarize_descriptions_config import (
SummarizeDescriptionsConfig,
)
from graphrag.config.models.text_embedding_config import TextEmbeddingConfig
from graphrag.config.models.umap_config import UmapConfig
from graphrag.config.read_dotenv import read_dotenv
def create_graphrag_config(
values: dict[str, Any] | None = None, root_dir: str | None = None
values: dict[str, Any] | None = None,
root_dir: str | None = None,
) -> GraphRagConfig:
"""Load Configuration Parameters from a dictionary."""
"""Load Configuration Parameters from a dictionary.
Parameters
----------
values : dict[str, Any] | None
Dictionary of configuration values to pass into pydantic model.
root_dir : str | None
Root directory for the project.
skip_validation : bool
Skip pydantic model validation of the configuration.
This is useful for testing and mocking purposes but
should not be used in the core code or API.
Returns
-------
GraphRagConfig
The configuration object.
Raises
------
ValidationError
If the configuration values do not satisfy pydantic validation.
"""
values = values or {}
root_dir = root_dir or str(Path.cwd())
env = _make_env(root_dir)
_token_replace(cast("dict", values))
reader = EnvironmentReader(env)
def hydrate_async_type(input: dict[str, Any], base: AsyncType) -> AsyncType:
value = input.get(Fragment.async_mode)
return AsyncType(value) if value else base
def hydrate_llm_params(
config: dict[str, Any], base: LLMParameters
) -> LLMParameters:
with reader.use(config.get("llm")):
llm_type = reader.str(Fragment.type)
llm_type = LLMType(llm_type) if llm_type else base.type
api_key = reader.str(Fragment.api_key) or base.api_key
api_base = reader.str(Fragment.api_base) or base.api_base
audience = reader.str(Fragment.audience) or base.audience
deployment_name = (
reader.str(Fragment.deployment_name) or base.deployment_name
)
encoding_model = reader.str(Fragment.encoding_model) or base.encoding_model
if api_key is None and not _is_azure(llm_type):
raise ApiKeyMissingError
if _is_azure(llm_type):
if api_base is None:
raise AzureApiBaseMissingError
if deployment_name is None:
raise AzureDeploymentNameMissingError
sleep_on_rate_limit = reader.bool(Fragment.sleep_recommendation)
if sleep_on_rate_limit is None:
sleep_on_rate_limit = base.sleep_on_rate_limit_recommendation
return LLMParameters(
api_key=api_key,
type=llm_type,
api_base=api_base,
api_version=reader.str(Fragment.api_version) or base.api_version,
organization=reader.str("organization") or base.organization,
proxy=reader.str("proxy") or base.proxy,
model=reader.str("model") or base.model,
encoding_model=encoding_model,
max_tokens=reader.int(Fragment.max_tokens) or base.max_tokens,
temperature=reader.float(Fragment.temperature) or base.temperature,
top_p=reader.float(Fragment.top_p) or base.top_p,
n=reader.int(Fragment.n) or base.n,
model_supports_json=reader.bool(Fragment.model_supports_json)
or base.model_supports_json,
request_timeout=reader.float(Fragment.request_timeout)
or base.request_timeout,
audience=audience,
deployment_name=deployment_name,
tokens_per_minute=reader.int("tokens_per_minute", Fragment.tpm)
or base.tokens_per_minute,
requests_per_minute=reader.int("requests_per_minute", Fragment.rpm)
or base.requests_per_minute,
max_retries=reader.int(Fragment.max_retries) or base.max_retries,
max_retry_wait=reader.float(Fragment.max_retry_wait)
or base.max_retry_wait,
sleep_on_rate_limit_recommendation=sleep_on_rate_limit,
concurrent_requests=reader.int(Fragment.concurrent_requests)
or base.concurrent_requests,
)
def hydrate_embeddings_params(
config: dict[str, Any], base: LLMParameters
) -> LLMParameters:
with reader.use(config.get("llm")):
api_type = reader.str(Fragment.type) or defs.EMBEDDING_TYPE
api_type = LLMType(api_type) if api_type else defs.LLM_TYPE
api_key = reader.str(Fragment.api_key) or base.api_key
# Account for various permutations of config settings such as:
# - same api_bases for LLM and embeddings (both Azure)
# - different api_bases for LLM and embeddings (both Azure)
# - LLM uses Azure OpenAI, while embeddings uses base OpenAI (this one is important)
# - LLM uses Azure OpenAI, while embeddings uses third-party OpenAI-like API
api_base = (
reader.str(Fragment.api_base) or base.api_base
if _is_azure(api_type)
else reader.str(Fragment.api_base)
)
api_version = (
reader.str(Fragment.api_version) or base.api_version
if _is_azure(api_type)
else reader.str(Fragment.api_version)
)
api_organization = reader.str("organization") or base.organization
api_proxy = reader.str("proxy") or base.proxy
audience = reader.str(Fragment.audience) or base.audience
deployment_name = reader.str(Fragment.deployment_name)
encoding_model = reader.str(Fragment.encoding_model) or base.encoding_model
if api_key is None and not _is_azure(api_type):
raise ApiKeyMissingError(embedding=True)
if _is_azure(api_type):
if api_base is None:
raise AzureApiBaseMissingError(embedding=True)
if deployment_name is None:
raise AzureDeploymentNameMissingError(embedding=True)
sleep_on_rate_limit = reader.bool(Fragment.sleep_recommendation)
if sleep_on_rate_limit is None:
sleep_on_rate_limit = base.sleep_on_rate_limit_recommendation
return LLMParameters(
api_key=api_key,
type=api_type,
api_base=api_base,
api_version=api_version,
organization=api_organization,
proxy=api_proxy,
model=reader.str(Fragment.model) or defs.EMBEDDING_MODEL,
encoding_model=encoding_model,
request_timeout=reader.float(Fragment.request_timeout)
or defs.LLM_REQUEST_TIMEOUT,
audience=audience,
deployment_name=deployment_name,
tokens_per_minute=reader.int("tokens_per_minute", Fragment.tpm)
or defs.LLM_TOKENS_PER_MINUTE,
requests_per_minute=reader.int("requests_per_minute", Fragment.rpm)
or defs.LLM_REQUESTS_PER_MINUTE,
max_retries=reader.int(Fragment.max_retries) or defs.LLM_MAX_RETRIES,
max_retry_wait=reader.float(Fragment.max_retry_wait)
or defs.LLM_MAX_RETRY_WAIT,
sleep_on_rate_limit_recommendation=sleep_on_rate_limit,
concurrent_requests=reader.int(Fragment.concurrent_requests)
or defs.LLM_CONCURRENT_REQUESTS,
)
def hydrate_parallelization_params(
config: dict[str, Any], base: ParallelizationParameters
) -> ParallelizationParameters:
with reader.use(config.get("parallelization")):
return ParallelizationParameters(
num_threads=reader.int("num_threads", Fragment.thread_count)
or base.num_threads,
stagger=reader.float("stagger", Fragment.thread_stagger)
or base.stagger,
)
fallback_oai_key = env("OPENAI_API_KEY", env("AZURE_OPENAI_API_KEY", None))
fallback_oai_org = env("OPENAI_ORG_ID", None)
fallback_oai_base = env("OPENAI_BASE_URL", None)
fallback_oai_version = env("OPENAI_API_VERSION", None)
with reader.envvar_prefix(Section.graphrag), reader.use(values):
async_mode = reader.str(Fragment.async_mode)
async_mode = AsyncType(async_mode) if async_mode else defs.ASYNC_MODE
fallback_oai_key = reader.str(Fragment.api_key) or fallback_oai_key
fallback_oai_org = reader.str(Fragment.api_organization) or fallback_oai_org
fallback_oai_base = reader.str(Fragment.api_base) or fallback_oai_base
fallback_oai_version = reader.str(Fragment.api_version) or fallback_oai_version
fallback_oai_proxy = reader.str(Fragment.api_proxy)
global_encoding_model = (
reader.str(Fragment.encoding_model) or defs.ENCODING_MODEL
)
with reader.envvar_prefix(Section.llm):
with reader.use(values.get("llm")):
llm_type = reader.str(Fragment.type)
llm_type = LLMType(llm_type) if llm_type else defs.LLM_TYPE
api_key = reader.str(Fragment.api_key) or fallback_oai_key
api_organization = (
reader.str(Fragment.api_organization) or fallback_oai_org
)
api_base = reader.str(Fragment.api_base) or fallback_oai_base
api_version = reader.str(Fragment.api_version) or fallback_oai_version
api_proxy = reader.str(Fragment.api_proxy) or fallback_oai_proxy
audience = reader.str(Fragment.audience)
deployment_name = reader.str(Fragment.deployment_name)
encoding_model = (
reader.str(Fragment.encoding_model) or global_encoding_model
)
if api_key is None and not _is_azure(llm_type):
raise ApiKeyMissingError
if _is_azure(llm_type):
if api_base is None:
raise AzureApiBaseMissingError
if deployment_name is None:
raise AzureDeploymentNameMissingError
sleep_on_rate_limit = reader.bool(Fragment.sleep_recommendation)
if sleep_on_rate_limit is None:
sleep_on_rate_limit = defs.LLM_SLEEP_ON_RATE_LIMIT_RECOMMENDATION
llm_model = LLMParameters(
api_key=api_key,
api_base=api_base,
api_version=api_version,
organization=api_organization,
proxy=api_proxy,
type=llm_type,
model=reader.str(Fragment.model) or defs.LLM_MODEL,
encoding_model=encoding_model,
max_tokens=reader.int(Fragment.max_tokens) or defs.LLM_MAX_TOKENS,
temperature=reader.float(Fragment.temperature)
or defs.LLM_TEMPERATURE,
top_p=reader.float(Fragment.top_p) or defs.LLM_TOP_P,
n=reader.int(Fragment.n) or defs.LLM_N,
model_supports_json=reader.bool(Fragment.model_supports_json),
request_timeout=reader.float(Fragment.request_timeout)
or defs.LLM_REQUEST_TIMEOUT,
audience=audience,
deployment_name=deployment_name,
tokens_per_minute=reader.int(Fragment.tpm)
or defs.LLM_TOKENS_PER_MINUTE,
requests_per_minute=reader.int(Fragment.rpm)
or defs.LLM_REQUESTS_PER_MINUTE,
max_retries=reader.int(Fragment.max_retries)
or defs.LLM_MAX_RETRIES,
max_retry_wait=reader.float(Fragment.max_retry_wait)
or defs.LLM_MAX_RETRY_WAIT,
sleep_on_rate_limit_recommendation=sleep_on_rate_limit,
concurrent_requests=reader.int(Fragment.concurrent_requests)
or defs.LLM_CONCURRENT_REQUESTS,
)
with reader.use(values.get("parallelization")):
llm_parallelization_model = ParallelizationParameters(
stagger=reader.float("stagger", Fragment.thread_stagger)
or defs.PARALLELIZATION_STAGGER,
num_threads=reader.int("num_threads", Fragment.thread_count)
or defs.PARALLELIZATION_NUM_THREADS,
)
embeddings_config = values.get("embeddings") or {}
with reader.envvar_prefix(Section.embedding), reader.use(embeddings_config):
embeddings_target = reader.str("target")
# TODO: remove the type ignore annotations below once the new config engine has been refactored
embeddings_model = TextEmbeddingConfig(
llm=hydrate_embeddings_params(embeddings_config, llm_model), # type: ignore
parallelization=hydrate_parallelization_params(
embeddings_config, # type: ignore
llm_parallelization_model, # type: ignore
),
vector_store=embeddings_config.get("vector_store", None),
async_mode=hydrate_async_type(embeddings_config, async_mode), # type: ignore
target=(
TextEmbeddingTarget(embeddings_target)
if embeddings_target
else defs.EMBEDDING_TARGET
),
batch_size=reader.int("batch_size") or defs.EMBEDDING_BATCH_SIZE,
batch_max_tokens=reader.int("batch_max_tokens")
or defs.EMBEDDING_BATCH_MAX_TOKENS,
skip=reader.list("skip") or [],
)
with (
reader.envvar_prefix(Section.node2vec),
reader.use(values.get("embed_graph")),
):
use_lcc = reader.bool("use_lcc")
embed_graph_model = EmbedGraphConfig(
enabled=reader.bool(Fragment.enabled) or defs.NODE2VEC_ENABLED,
dimensions=reader.int("dimensions") or defs.NODE2VEC_DIMENSIONS,
num_walks=reader.int("num_walks") or defs.NODE2VEC_NUM_WALKS,
walk_length=reader.int("walk_length") or defs.NODE2VEC_WALK_LENGTH,
window_size=reader.int("window_size") or defs.NODE2VEC_WINDOW_SIZE,
iterations=reader.int("iterations") or defs.NODE2VEC_ITERATIONS,
random_seed=reader.int("random_seed") or defs.NODE2VEC_RANDOM_SEED,
use_lcc=use_lcc if use_lcc is not None else defs.USE_LCC,
)
with reader.envvar_prefix(Section.input), reader.use(values.get("input")):
input_type = reader.str("type")
file_type = reader.str(Fragment.file_type)
input_model = InputConfig(
file_type=(
InputFileType(file_type) if file_type else defs.INPUT_FILE_TYPE
),
type=(InputType(input_type) if input_type else defs.INPUT_TYPE),
encoding=reader.str("file_encoding", Fragment.encoding)
or defs.INPUT_FILE_ENCODING,
base_dir=reader.str(Fragment.base_dir) or defs.INPUT_BASE_DIR,
file_pattern=reader.str("file_pattern")
or (
defs.INPUT_TEXT_PATTERN
if file_type == InputFileType.text
else defs.INPUT_CSV_PATTERN
),
source_column=reader.str("source_column"),
timestamp_column=reader.str("timestamp_column"),
timestamp_format=reader.str("timestamp_format"),
text_column=reader.str("text_column") or defs.INPUT_TEXT_COLUMN,
title_column=reader.str("title_column"),
document_attribute_columns=reader.list("document_attribute_columns")
or [],
connection_string=reader.str(Fragment.conn_string),
storage_account_blob_url=reader.str(Fragment.storage_account_blob_url),
container_name=reader.str(Fragment.container_name),
)
with reader.envvar_prefix(Section.cache), reader.use(values.get("cache")):
c_type = reader.str(Fragment.type)
cache_model = CacheConfig(
type=CacheType(c_type) if c_type else defs.CACHE_TYPE,
connection_string=reader.str(Fragment.conn_string),
storage_account_blob_url=reader.str(Fragment.storage_account_blob_url),
container_name=reader.str(Fragment.container_name),
base_dir=reader.str(Fragment.base_dir) or defs.CACHE_BASE_DIR,
cosmosdb_account_url=reader.str(Fragment.cosmosdb_account_url),
)
with (
reader.envvar_prefix(Section.reporting),
reader.use(values.get("reporting")),
):
r_type = reader.str(Fragment.type)
reporting_model = ReportingConfig(
type=ReportingType(r_type) if r_type else defs.REPORTING_TYPE,
connection_string=reader.str(Fragment.conn_string),
storage_account_blob_url=reader.str(Fragment.storage_account_blob_url),
container_name=reader.str(Fragment.container_name),
base_dir=reader.str(Fragment.base_dir) or defs.REPORTING_BASE_DIR,
)
with reader.envvar_prefix(Section.storage), reader.use(values.get("storage")):
s_type = reader.str(Fragment.type)
storage_model = StorageConfig(
type=StorageType(s_type) if s_type else defs.STORAGE_TYPE,
connection_string=reader.str(Fragment.conn_string),
storage_account_blob_url=reader.str(Fragment.storage_account_blob_url),
container_name=reader.str(Fragment.container_name),
base_dir=reader.str(Fragment.base_dir) or defs.STORAGE_BASE_DIR,
cosmosdb_account_url=reader.str(Fragment.cosmosdb_account_url),
)
with (
reader.envvar_prefix(Section.update_index_storage),
reader.use(values.get("update_index_storage")),
):
s_type = reader.str(Fragment.type)
if s_type:
update_index_storage_model = StorageConfig(
type=StorageType(s_type) if s_type else defs.STORAGE_TYPE,
connection_string=reader.str(Fragment.conn_string),
storage_account_blob_url=reader.str(
Fragment.storage_account_blob_url
),
container_name=reader.str(Fragment.container_name),
base_dir=reader.str(Fragment.base_dir)
or defs.UPDATE_STORAGE_BASE_DIR,
)
else:
update_index_storage_model = None
with reader.envvar_prefix(Section.chunk), reader.use(values.get("chunks")):
group_by_columns = reader.list("group_by_columns", "BY_COLUMNS")
if group_by_columns is None:
group_by_columns = defs.CHUNK_GROUP_BY_COLUMNS
encoding_model = (
reader.str(Fragment.encoding_model) or global_encoding_model
)
strategy = reader.str("strategy")
chunks_model = ChunkingConfig(
size=reader.int("size") or defs.CHUNK_SIZE,
overlap=reader.int("overlap") or defs.CHUNK_OVERLAP,
group_by_columns=group_by_columns,
encoding_model=encoding_model,
strategy=ChunkStrategyType(strategy)
if strategy
else ChunkStrategyType.tokens,
)
with (
reader.envvar_prefix(Section.snapshot),
reader.use(values.get("snapshots")),
):
snapshots_model = SnapshotsConfig(
graphml=reader.bool("graphml") or defs.SNAPSHOTS_GRAPHML,
embeddings=reader.bool("embeddings") or defs.SNAPSHOTS_EMBEDDINGS,
transient=reader.bool("transient") or defs.SNAPSHOTS_TRANSIENT,
)
with reader.envvar_prefix(Section.umap), reader.use(values.get("umap")):
umap_model = UmapConfig(
enabled=reader.bool(Fragment.enabled) or defs.UMAP_ENABLED,
)
entity_extraction_config = values.get("entity_extraction") or {}
with (
reader.envvar_prefix(Section.entity_extraction),
reader.use(entity_extraction_config),
):
max_gleanings = reader.int(Fragment.max_gleanings)
max_gleanings = (
max_gleanings
if max_gleanings is not None
else defs.ENTITY_EXTRACTION_MAX_GLEANINGS
)
encoding_model = (
reader.str(Fragment.encoding_model) or global_encoding_model
)
entity_extraction_model = EntityExtractionConfig(
llm=hydrate_llm_params(entity_extraction_config, llm_model),
parallelization=hydrate_parallelization_params(
entity_extraction_config, llm_parallelization_model
),
async_mode=hydrate_async_type(entity_extraction_config, async_mode),
entity_types=reader.list("entity_types")
or defs.ENTITY_EXTRACTION_ENTITY_TYPES,
max_gleanings=max_gleanings,
prompt=reader.str("prompt", Fragment.prompt_file),
strategy=entity_extraction_config.get("strategy"),
encoding_model=encoding_model,
)
claim_extraction_config = values.get("claim_extraction") or {}
with (
reader.envvar_prefix(Section.claim_extraction),
reader.use(claim_extraction_config),
):
max_gleanings = reader.int(Fragment.max_gleanings)
max_gleanings = (
max_gleanings if max_gleanings is not None else defs.CLAIM_MAX_GLEANINGS
)
encoding_model = (
reader.str(Fragment.encoding_model) or global_encoding_model
)
claim_extraction_model = ClaimExtractionConfig(
enabled=reader.bool(Fragment.enabled) or defs.CLAIM_EXTRACTION_ENABLED,
llm=hydrate_llm_params(claim_extraction_config, llm_model),
parallelization=hydrate_parallelization_params(
claim_extraction_config, llm_parallelization_model
),
async_mode=hydrate_async_type(claim_extraction_config, async_mode),
description=reader.str("description") or defs.CLAIM_DESCRIPTION,
prompt=reader.str("prompt", Fragment.prompt_file),
max_gleanings=max_gleanings,
encoding_model=encoding_model,
)
community_report_config = values.get("community_reports") or {}
with (
reader.envvar_prefix(Section.community_reports),
reader.use(community_report_config),
):
community_reports_model = CommunityReportsConfig(
llm=hydrate_llm_params(community_report_config, llm_model),
parallelization=hydrate_parallelization_params(
community_report_config, llm_parallelization_model
),
async_mode=hydrate_async_type(community_report_config, async_mode),
prompt=reader.str("prompt", Fragment.prompt_file),
max_length=reader.int(Fragment.max_length)
or defs.COMMUNITY_REPORT_MAX_LENGTH,
max_input_length=reader.int("max_input_length")
or defs.COMMUNITY_REPORT_MAX_INPUT_LENGTH,
)
summarize_description_config = values.get("summarize_descriptions") or {}
with (
reader.envvar_prefix(Section.summarize_descriptions),
reader.use(values.get("summarize_descriptions")),
):
summarize_descriptions_model = SummarizeDescriptionsConfig(
llm=hydrate_llm_params(summarize_description_config, llm_model),
parallelization=hydrate_parallelization_params(
summarize_description_config, llm_parallelization_model
),
async_mode=hydrate_async_type(summarize_description_config, async_mode),
prompt=reader.str("prompt", Fragment.prompt_file),
max_length=reader.int(Fragment.max_length)
or defs.SUMMARIZE_DESCRIPTIONS_MAX_LENGTH,
)
with reader.use(values.get("cluster_graph")):
use_lcc = reader.bool("use_lcc")
seed = reader.int("seed")
cluster_graph_model = ClusterGraphConfig(
max_cluster_size=reader.int("max_cluster_size")
or defs.MAX_CLUSTER_SIZE,
use_lcc=use_lcc if use_lcc is not None else defs.USE_LCC,
seed=seed if seed is not None else defs.CLUSTER_GRAPH_SEED,
)
with (
reader.use(values.get("local_search")),
reader.envvar_prefix(Section.local_search),
):
local_search_model = LocalSearchConfig(
prompt=reader.str("prompt") or None,
text_unit_prop=reader.float("text_unit_prop")
or defs.LOCAL_SEARCH_TEXT_UNIT_PROP,
community_prop=reader.float("community_prop")
or defs.LOCAL_SEARCH_COMMUNITY_PROP,
conversation_history_max_turns=reader.int(
"conversation_history_max_turns"
)
or defs.LOCAL_SEARCH_CONVERSATION_HISTORY_MAX_TURNS,
top_k_entities=reader.int("top_k_entities")
or defs.LOCAL_SEARCH_TOP_K_MAPPED_ENTITIES,
top_k_relationships=reader.int("top_k_relationships")
or defs.LOCAL_SEARCH_TOP_K_RELATIONSHIPS,
temperature=reader.float("llm_temperature")
or defs.LOCAL_SEARCH_LLM_TEMPERATURE,
top_p=reader.float("llm_top_p") or defs.LOCAL_SEARCH_LLM_TOP_P,
n=reader.int("llm_n") or defs.LOCAL_SEARCH_LLM_N,
max_tokens=reader.int(Fragment.max_tokens)
or defs.LOCAL_SEARCH_MAX_TOKENS,
llm_max_tokens=reader.int("llm_max_tokens")
or defs.LOCAL_SEARCH_LLM_MAX_TOKENS,
)
with (
reader.use(values.get("global_search")),
reader.envvar_prefix(Section.global_search),
):
global_search_model = GlobalSearchConfig(
map_prompt=reader.str("map_prompt") or None,
reduce_prompt=reader.str("reduce_prompt") or None,
knowledge_prompt=reader.str("knowledge_prompt") or None,
temperature=reader.float("llm_temperature")
or defs.GLOBAL_SEARCH_LLM_TEMPERATURE,
top_p=reader.float("llm_top_p") or defs.GLOBAL_SEARCH_LLM_TOP_P,
n=reader.int("llm_n") or defs.GLOBAL_SEARCH_LLM_N,
max_tokens=reader.int(Fragment.max_tokens)
or defs.GLOBAL_SEARCH_MAX_TOKENS,
data_max_tokens=reader.int("data_max_tokens")
or defs.GLOBAL_SEARCH_DATA_MAX_TOKENS,
map_max_tokens=reader.int("map_max_tokens")
or defs.GLOBAL_SEARCH_MAP_MAX_TOKENS,
reduce_max_tokens=reader.int("reduce_max_tokens")
or defs.GLOBAL_SEARCH_REDUCE_MAX_TOKENS,
concurrency=reader.int("concurrency") or defs.GLOBAL_SEARCH_CONCURRENCY,
)
with (
reader.use(values.get("drift_search")),
reader.envvar_prefix(Section.drift_search),
):
drift_search_model = DRIFTSearchConfig(
prompt=reader.str("prompt") or None,
reduce_prompt=reader.str("reduce_prompt") or None,
temperature=reader.float("llm_temperature")
or defs.DRIFT_SEARCH_LLM_TEMPERATURE,
top_p=reader.float("llm_top_p") or defs.DRIFT_SEARCH_LLM_TOP_P,
n=reader.int("llm_n") or defs.DRIFT_SEARCH_LLM_N,
max_tokens=reader.int(Fragment.max_tokens)
or defs.DRIFT_SEARCH_MAX_TOKENS,
data_max_tokens=reader.int("data_max_tokens")
or defs.DRIFT_SEARCH_DATA_MAX_TOKENS,
reduce_max_tokens=reader.int("reduce_max_tokens")
or defs.DRIFT_SEARCH_REDUCE_MAX_TOKENS,
reduce_temperature=reader.float("reduce_temperature")
or defs.DRIFT_SEARCH_REDUCE_LLM_TEMPERATURE,
concurrency=reader.int("concurrency") or defs.DRIFT_SEARCH_CONCURRENCY,
drift_k_followups=reader.int("drift_k_followups")
or defs.DRIFT_SEARCH_K_FOLLOW_UPS,
primer_folds=reader.int("primer_folds")
or defs.DRIFT_SEARCH_PRIMER_FOLDS,
primer_llm_max_tokens=reader.int("primer_llm_max_tokens")
or defs.DRIFT_SEARCH_PRIMER_MAX_TOKENS,
n_depth=reader.int("n_depth") or defs.DRIFT_N_DEPTH,
local_search_text_unit_prop=reader.float("local_search_text_unit_prop")
or defs.DRIFT_LOCAL_SEARCH_TEXT_UNIT_PROP,
local_search_community_prop=reader.float("local_search_community_prop")
or defs.DRIFT_LOCAL_SEARCH_COMMUNITY_PROP,
local_search_top_k_mapped_entities=reader.int(
"local_search_top_k_mapped_entities"
)
or defs.DRIFT_LOCAL_SEARCH_TOP_K_MAPPED_ENTITIES,
local_search_top_k_relationships=reader.int(
"local_search_top_k_relationships"
)
or defs.DRIFT_LOCAL_SEARCH_TOP_K_RELATIONSHIPS,
local_search_max_data_tokens=reader.int("local_search_max_data_tokens")
or defs.DRIFT_LOCAL_SEARCH_MAX_TOKENS,
local_search_temperature=reader.float("local_search_temperature")
or defs.DRIFT_LOCAL_SEARCH_LLM_TEMPERATURE,
local_search_top_p=reader.float("local_search_top_p")
or defs.DRIFT_LOCAL_SEARCH_LLM_TOP_P,
local_search_n=reader.int("local_search_n")
or defs.DRIFT_LOCAL_SEARCH_LLM_N,
local_search_llm_max_gen_tokens=reader.int(
"local_search_llm_max_gen_tokens"
)
or defs.DRIFT_LOCAL_SEARCH_LLM_MAX_TOKENS,
)
with (
reader.use(values.get("basic_search")),
reader.envvar_prefix(Section.basic_search),
):
basic_search_model = BasicSearchConfig(
prompt=reader.str("prompt") or None,
text_unit_prop=reader.float("text_unit_prop")
or defs.BASIC_SEARCH_TEXT_UNIT_PROP,
conversation_history_max_turns=reader.int(
"conversation_history_max_turns"
)
or defs.BASIC_SEARCH_CONVERSATION_HISTORY_MAX_TURNS,
temperature=reader.float("llm_temperature")
or defs.BASIC_SEARCH_LLM_TEMPERATURE,
top_p=reader.float("llm_top_p") or defs.BASIC_SEARCH_LLM_TOP_P,
n=reader.int("llm_n") or defs.BASIC_SEARCH_LLM_N,
max_tokens=reader.int(Fragment.max_tokens)
or defs.BASIC_SEARCH_MAX_TOKENS,
llm_max_tokens=reader.int("llm_max_tokens")
or defs.BASIC_SEARCH_LLM_MAX_TOKENS,
)
skip_workflows = reader.list("skip_workflows") or []
return GraphRagConfig(
root_dir=root_dir,
llm=llm_model,
parallelization=llm_parallelization_model,
async_mode=async_mode,
embeddings=embeddings_model,
embed_graph=embed_graph_model,
reporting=reporting_model,
storage=storage_model,
update_index_storage=update_index_storage_model,
cache=cache_model,
input=input_model,
chunks=chunks_model,
snapshots=snapshots_model,
entity_extraction=entity_extraction_model,
claim_extraction=claim_extraction_model,
community_reports=community_reports_model,
summarize_descriptions=summarize_descriptions_model,
umap=umap_model,
cluster_graph=cluster_graph_model,
encoding_model=global_encoding_model,
skip_workflows=skip_workflows,
local_search=local_search_model,
global_search=global_search_model,
drift_search=drift_search_model,
basic_search=basic_search_model,
)
class Fragment(str, Enum):
"""Configuration Fragments."""
api_base = "API_BASE"
api_key = "API_KEY"
api_version = "API_VERSION"
api_organization = "API_ORGANIZATION"
api_proxy = "API_PROXY"
async_mode = "ASYNC_MODE"
audience = "AUDIENCE"
base_dir = "BASE_DIR"
concurrent_requests = "CONCURRENT_REQUESTS"
conn_string = "CONNECTION_STRING"
container_name = "CONTAINER_NAME"
cosmosdb_account_url = "COSMOSDB_ACCOUNT_URL"
deployment_name = "DEPLOYMENT_NAME"
description = "DESCRIPTION"
enabled = "ENABLED"
encoding = "ENCODING"
encoding_model = "ENCODING_MODEL"
file_type = "FILE_TYPE"
max_gleanings = "MAX_GLEANINGS"
max_length = "MAX_LENGTH"
max_retries = "MAX_RETRIES"
max_retry_wait = "MAX_RETRY_WAIT"
max_tokens = "MAX_TOKENS"
temperature = "TEMPERATURE"
top_p = "TOP_P"
n = "N"
model = "MODEL"
model_supports_json = "MODEL_SUPPORTS_JSON"
prompt_file = "PROMPT_FILE"
request_timeout = "REQUEST_TIMEOUT"
rpm = "REQUESTS_PER_MINUTE"
sleep_recommendation = "SLEEP_ON_RATE_LIMIT_RECOMMENDATION"
storage_account_blob_url = "STORAGE_ACCOUNT_BLOB_URL"
thread_count = "THREAD_COUNT"
thread_stagger = "THREAD_STAGGER"
tpm = "TOKENS_PER_MINUTE"
type = "TYPE"
class Section(str, Enum):
"""Configuration Sections."""
base = "BASE"
cache = "CACHE"
chunk = "CHUNK"
claim_extraction = "CLAIM_EXTRACTION"
community_reports = "COMMUNITY_REPORTS"
embedding = "EMBEDDING"
entity_extraction = "ENTITY_EXTRACTION"
graphrag = "GRAPHRAG"
input = "INPUT"
llm = "LLM"
node2vec = "NODE2VEC"
reporting = "REPORTING"
snapshot = "SNAPSHOT"
storage = "STORAGE"
summarize_descriptions = "SUMMARIZE_DESCRIPTIONS"
umap = "UMAP"
update_index_storage = "UPDATE_INDEX_STORAGE"
local_search = "LOCAL_SEARCH"
global_search = "GLOBAL_SEARCH"
drift_search = "DRIFT_SEARCH"
basic_search = "BASIC_SEARCH"
def _is_azure(llm_type: LLMType | None) -> bool:
return (
llm_type == LLMType.AzureOpenAIChat or llm_type == LLMType.AzureOpenAIEmbedding
)
def _make_env(root_dir: str) -> Env:
read_dotenv(root_dir)
env = Env(expand_vars=True)
env.read_env()
return env
def _token_replace(data: dict):
"""Replace env-var tokens in a dictionary object."""
for key, value in data.items():
if isinstance(value, dict):
_token_replace(value)
elif isinstance(value, str):
data[key] = os.path.expandvars(value)
if root_dir:
root_path = Path(root_dir).resolve()
values["root_dir"] = str(root_path)
return GraphRagConfig(**values)

View File

@ -8,15 +8,18 @@ from pathlib import Path
from graphrag.config.enums import (
AsyncType,
CacheType,
ChunkStrategyType,
InputFileType,
InputType,
LLMType,
OutputType,
ReportingType,
StorageType,
TextEmbeddingTarget,
)
from graphrag.vector_stores.factory import VectorStoreType
DEFAULT_CHAT_MODEL_ID = "default_chat_model"
DEFAULT_EMBEDDING_MODEL_ID = "default_embedding_model"
ASYNC_MODE = AsyncType.Threaded
ENCODING_MODEL = "cl100k_base"
AZURE_AUDIENCE = "https://cognitiveservices.azure.com/.default"
@ -47,24 +50,29 @@ EMBEDDING_MODEL = "text-embedding-3-small"
EMBEDDING_BATCH_SIZE = 16
EMBEDDING_BATCH_MAX_TOKENS = 8191
EMBEDDING_TARGET = TextEmbeddingTarget.required
EMBEDDING_MODEL_ID = DEFAULT_EMBEDDING_MODEL_ID
CACHE_TYPE = CacheType.file
CACHE_BASE_DIR = "cache"
CHUNK_SIZE = 1200
CHUNK_OVERLAP = 100
CHUNK_GROUP_BY_COLUMNS = ["id"]
CHUNK_STRATEGY = ChunkStrategyType.tokens
CLAIM_DESCRIPTION = (
"Any claims or facts that could be relevant to information discovery."
)
CLAIM_MAX_GLEANINGS = 1
CLAIM_EXTRACTION_ENABLED = False
CLAIM_EXTRACTION_MODEL_ID = DEFAULT_CHAT_MODEL_ID
MAX_CLUSTER_SIZE = 10
USE_LCC = True
CLUSTER_GRAPH_SEED = 0xDEADBEEF
COMMUNITY_REPORT_MAX_LENGTH = 2000
COMMUNITY_REPORT_MAX_INPUT_LENGTH = 8000
COMMUNITY_REPORT_MODEL_ID = DEFAULT_CHAT_MODEL_ID
ENTITY_EXTRACTION_ENTITY_TYPES = ["organization", "person", "geo", "event"]
ENTITY_EXTRACTION_MAX_GLEANINGS = 1
ENTITY_EXTRACTION_MODEL_ID = DEFAULT_CHAT_MODEL_ID
INPUT_FILE_TYPE = InputFileType.text
INPUT_TYPE = InputType.file
INPUT_BASE_DIR = "input"
@ -86,25 +94,18 @@ REPORTING_BASE_DIR = "logs"
SNAPSHOTS_GRAPHML = False
SNAPSHOTS_EMBEDDINGS = False
SNAPSHOTS_TRANSIENT = False
STORAGE_BASE_DIR = "output"
STORAGE_TYPE = StorageType.file
OUTPUT_BASE_DIR = "output"
OUTPUT_TYPE = OutputType.file
SUMMARIZE_DESCRIPTIONS_MAX_LENGTH = 500
SUMMARIZE_MODEL_ID = DEFAULT_CHAT_MODEL_ID
UMAP_ENABLED = False
UPDATE_STORAGE_BASE_DIR = "update_output"
UPDATE_OUTPUT_BASE_DIR = "update_output"
VECTOR_STORE = f"""
type: {VectorStoreType.LanceDB.value} # one of [lancedb, azure_ai_search, cosmosdb]
db_uri: '{(Path(STORAGE_BASE_DIR) / "lancedb")!s}'
collection_name: default
overwrite: true\
"""
VECTOR_STORE_DICT = {
"type": VectorStoreType.LanceDB.value,
"db_uri": str(Path(STORAGE_BASE_DIR) / "lancedb"),
"collection_name": "default",
"overwrite": True,
}
VECTOR_STORE_TYPE = VectorStoreType.LanceDB.value
VECTOR_STORE_DB_URI = str(Path(OUTPUT_BASE_DIR) / "lancedb")
VECTOR_STORE_CONTAINER_NAME = "default"
VECTOR_STORE_OVERWRITE = True
# Local Search
LOCAL_SEARCH_TEXT_UNIT_PROP = 0.5

View File

@ -5,7 +5,6 @@
from graphrag.config.enums import TextEmbeddingTarget
from graphrag.config.models.graph_rag_config import GraphRagConfig
from graphrag.config.models.text_embedding_config import TextEmbeddingConfig
entity_title_embedding = "entity.title"
entity_description_embedding = "entity.description"
@ -34,12 +33,14 @@ required_embeddings: set[str] = {
def get_embedded_fields(settings: GraphRagConfig) -> set[str]:
"""Get the fields to embed based on the enum or specifically skipped embeddings."""
"""Get the fields to embed based on the enum or specifically selected embeddings."""
match settings.embeddings.target:
case TextEmbeddingTarget.all:
return all_embeddings.difference(settings.embeddings.skip)
return all_embeddings
case TextEmbeddingTarget.required:
return required_embeddings
case TextEmbeddingTarget.selected:
return set(settings.embeddings.names)
case TextEmbeddingTarget.none:
return set()
case _:
@ -48,24 +49,53 @@ def get_embedded_fields(settings: GraphRagConfig) -> set[str]:
def get_embedding_settings(
settings: TextEmbeddingConfig,
settings: GraphRagConfig,
vector_store_params: dict | None = None,
) -> dict:
"""Transform GraphRAG config into settings for workflows."""
# TEMP
vector_store_settings = settings.vector_store
embeddings_llm_settings = settings.get_language_model_config(
settings.embeddings.model_id
)
vector_store_settings = settings.vector_store.model_dump()
if vector_store_settings is None:
return {"strategy": settings.resolved_strategy()}
return {
"strategy": settings.embeddings.resolved_strategy(embeddings_llm_settings)
}
#
# If we get to this point, settings.vector_store is defined, and there's a specific setting for this embedding.
# settings.vector_store.base contains connection information, or may be undefined
# settings.vector_store.<vector_name> contains the specific settings for this embedding
#
strategy = settings.resolved_strategy() # get the default strategy
strategy = settings.embeddings.resolved_strategy(
embeddings_llm_settings
) # get the default strategy
strategy.update({
"vector_store": {**(vector_store_params or {}), **vector_store_settings}
"vector_store": {
**(vector_store_params or {}),
**(vector_store_settings),
}
}) # update the default strategy with the vector store settings
# This ensures the vector store config is part of the strategy and not the global config
return {
"strategy": strategy,
}
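# Hedged sketch (not part of the diff): with the defaults above, the returned value is
# roughly the embeddings strategy plus a merged vector_store block, e.g.
#   {
#       "strategy": {
#           "type": "openai",               # TextEmbedStrategyType.openai
#           "llm": {...},                   # dump of the embeddings LanguageModelConfig
#           "stagger": ...,                 # model_config.parallelization_stagger
#           "num_threads": ...,             # model_config.parallelization_num_threads
#           "vector_store": {"type": "lancedb", "db_uri": "<root>/output/lancedb",
#                            "container_name": "default", "overwrite": True, ...},
#       }
#   }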
def create_collection_name(
container_name: str, embedding_name: str, validate: bool = True
) -> str:
"""
Create a collection name for the embedding store.
Within any given vector store, we can have multiple sets of embeddings organized into projects.
The `container` param is used for this partitioning, and is added as a prefix to the collection name for differentiation.
The embedding name is fixed, with the available list defined in graphrag.index.config.embeddings
Note that we use dot notation in our names, but many vector stores do not support this - so we convert to dashes.
"""
if validate and embedding_name not in all_embeddings:
msg = f"Invalid embedding name: {embedding_name}"
raise KeyError(msg)
return f"{container_name}-{embedding_name}".replace(".", "-")
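# Usage sketch (illustrative): the container name becomes a prefix and dots become dashes,
# keeping names valid for stores that disallow "." in collection names.
#   create_collection_name("default", "entity.description")     # -> "default-entity-description"
#   create_collection_name("default", "not.a.known.embedding")  # raises KeyError when validate=True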

View File

@ -1,7 +1,7 @@
# Copyright (c) 2024 Microsoft Corporation.
# Licensed under the MIT License
"""A module containing 'PipelineCacheConfig', 'PipelineFileCacheConfig' and 'PipelineMemoryCacheConfig' models."""
"""A module containing config enums."""
from __future__ import annotations
@ -53,17 +53,17 @@ class InputType(str, Enum):
return f'"{self.value}"'
class StorageType(str, Enum):
"""The storage type for the pipeline."""
class OutputType(str, Enum):
"""The output type for the pipeline."""
file = "file"
"""The file storage type."""
"""The file output type."""
memory = "memory"
"""The memory storage type."""
"""The memory output type."""
blob = "blob"
"""The blob storage type."""
"""The blob output type."""
cosmosdb = "cosmosdb"
"""The cosmosdb storage type"""
"""The cosmosdb output type"""
def __repr__(self):
"""Get a string representation."""
@ -90,6 +90,7 @@ class TextEmbeddingTarget(str, Enum):
all = "all"
required = "required"
selected = "selected"
none = "none"
def __repr__(self):
@ -116,8 +117,26 @@ class LLMType(str, Enum):
return f'"{self.value}"'
class AzureAuthType(str, Enum):
"""AzureAuthType enum class definition."""
APIKey = "api_key"
ManagedIdentity = "managed_identity"
class AsyncType(str, Enum):
"""Enum for the type of async to use."""
AsyncIO = "asyncio"
Threaded = "threaded"
class ChunkStrategyType(str, Enum):
"""ChunkStrategy class definition."""
tokens = "tokens"
sentence = "sentence"
def __repr__(self):
"""Get a string representation."""
return f'"{self.value}"'

View File

@ -6,35 +6,54 @@
class ApiKeyMissingError(ValueError):
"""LLM Key missing error."""
def __init__(self, embedding: bool = False) -> None:
def __init__(self, llm_type: str, azure_auth_type: str | None = None) -> None:
"""Init method definition."""
api_type = "Embedding" if embedding else "Completion"
api_key = "GRAPHRAG_EMBEDDING_API_KEY" if embedding else "GRAPHRAG_LLM_API_KEY"
msg = f"API Key is required for {api_type} API. Please set either the OPENAI_API_KEY, GRAPHRAG_API_KEY or {api_key} environment variable."
msg = f"API Key is required for {llm_type}"
if azure_auth_type:
msg += f" when using {azure_auth_type} authentication"
msg += ". Please rerun `graphrag init` and set the API_KEY."
super().__init__(msg)
class AzureApiBaseMissingError(ValueError):
"""Azure API Base missing error."""
def __init__(self, embedding: bool = False) -> None:
def __init__(self, llm_type: str) -> None:
"""Init method definition."""
api_type = "Embedding" if embedding else "Completion"
api_base = "GRAPHRAG_EMBEDDING_API_BASE" if embedding else "GRAPHRAG_API_BASE"
msg = f"API Base is required for {api_type} API. Please set either the OPENAI_API_BASE, GRAPHRAG_API_BASE or {api_base} environment variable."
msg = f"API Base is required for {llm_type}. Please rerun `graphrag init` and set the api_base."
super().__init__(msg)
class AzureApiVersionMissingError(ValueError):
"""Azure API version missing error."""
def __init__(self, llm_type: str) -> None:
"""Init method definition."""
msg = f"API Version is required for {llm_type}. Please rerun `graphrag init` and set the api_version."
super().__init__(msg)
class AzureDeploymentNameMissingError(ValueError):
"""Azure Deployment Name missing error."""
def __init__(self, embedding: bool = False) -> None:
def __init__(self, llm_type: str) -> None:
"""Init method definition."""
msg = f"Deployment name is required for {llm_type}. Please rerun `graphrag init` set the deployment_name."
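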
super().__init__(msg)
class LanguageModelConfigMissingError(ValueError):
"""Missing model configuration error."""
def __init__(self, key: str = "") -> None:
"""Init method definition."""
msg = f'A {key} model configuration is required. Please rerun `graphrag init` and set models["{key}"] in settings.yaml.'
super().__init__(msg)
class ConflictingSettingsError(ValueError):
"""Missing model configuration error."""
def __init__(self, msg: str) -> None:
"""Init method definition."""
api_type = "Embedding" if embedding else "Completion"
api_base = (
"GRAPHRAG_EMBEDDING_DEPLOYMENT_NAME"
if embedding
else "GRAPHRAG_LLM_DEPLOYMENT_NAME"
)
msg = f"Deployment Name is required for {api_type} API. Please set either the OPENAI_DEPLOYMENT_NAME, GRAPHRAG_LLM_DEPLOYMENT_NAME or {api_base} environment variable."
super().__init__(msg)

View File

@ -12,38 +12,42 @@ INIT_YAML = f"""\
### LLM settings ###
## There are a number of settings to tune the threading and token limits for LLM calls - check the docs.
encoding_model: {defs.ENCODING_MODEL} # this needs to be matched to your model!
llm:
api_key: ${{GRAPHRAG_API_KEY}} # set this in the generated .env file
type: {defs.LLM_TYPE.value} # or azure_openai_chat
model: {defs.LLM_MODEL}
model_supports_json: true # recommended if this is available for your model.
# audience: "https://cognitiveservices.azure.com/.default"
# api_base: https://<instance>.openai.azure.com
# api_version: 2024-02-15-preview
# organization: <organization_id>
# deployment_name: <azure_model_deployment_name>
parallelization:
stagger: {defs.PARALLELIZATION_STAGGER}
# num_threads: {defs.PARALLELIZATION_NUM_THREADS}
async_mode: {defs.ASYNC_MODE.value} # or asyncio
embeddings:
async_mode: {defs.ASYNC_MODE.value} # or asyncio
vector_store: {defs.VECTOR_STORE}
llm:
models:
{defs.DEFAULT_CHAT_MODEL_ID}:
api_key: ${{GRAPHRAG_API_KEY}} # set this in the generated .env file
type: {defs.LLM_TYPE.value} # or azure_openai_chat
model: {defs.LLM_MODEL}
model_supports_json: true # recommended if this is available for your model.
parallelization_num_threads: {defs.PARALLELIZATION_NUM_THREADS}
parallelization_stagger: {defs.PARALLELIZATION_STAGGER}
async_mode: {defs.ASYNC_MODE.value} # or asyncio
# audience: "https://cognitiveservices.azure.com/.default"
# api_base: https://<instance>.openai.azure.com
# api_version: 2024-02-15-preview
# organization: <organization_id>
# deployment_name: <azure_model_deployment_name>
{defs.DEFAULT_EMBEDDING_MODEL_ID}:
api_key: ${{GRAPHRAG_API_KEY}}
type: {defs.EMBEDDING_TYPE.value} # or azure_openai_embedding
model: {defs.EMBEDDING_MODEL}
parallelization_num_threads: {defs.PARALLELIZATION_NUM_THREADS}
parallelization_stagger: {defs.PARALLELIZATION_STAGGER}
async_mode: {defs.ASYNC_MODE.value} # or asyncio
# api_base: https://<instance>.openai.azure.com
# api_version: 2024-02-15-preview
# audience: "https://cognitiveservices.azure.com/.default"
# organization: <organization_id>
# deployment_name: <azure_model_deployment_name>
vector_store:
type: {defs.VECTOR_STORE_TYPE}
db_uri: {defs.VECTOR_STORE_DB_URI}
container_name: {defs.VECTOR_STORE_CONTAINER_NAME}
overwrite: {defs.VECTOR_STORE_OVERWRITE}
embeddings:
model_id: {defs.DEFAULT_EMBEDDING_MODEL_ID}
### Input settings ###
input:
@ -51,14 +55,14 @@ input:
file_type: {defs.INPUT_FILE_TYPE.value} # or csv
base_dir: "{defs.INPUT_BASE_DIR}"
file_encoding: {defs.INPUT_FILE_ENCODING}
file_pattern: ".*\\\\.txt$"
file_pattern: ".*\\\\.txt$$"
chunks:
size: {defs.CHUNK_SIZE}
overlap: {defs.CHUNK_OVERLAP}
group_by_columns: [{",".join(defs.CHUNK_GROUP_BY_COLUMNS)}]
### Storage settings ###
### Output settings ###
## If blob storage is specified in the following four sections,
## connection_string and container_name must be provided
@ -70,36 +74,38 @@ reporting:
type: {defs.REPORTING_TYPE.value} # or console, blob
base_dir: "{defs.REPORTING_BASE_DIR}"
storage:
type: {defs.STORAGE_TYPE.value} # one of [blob, cosmosdb, file]
base_dir: "{defs.STORAGE_BASE_DIR}"
output:
type: {defs.OUTPUT_TYPE.value} # one of [blob, cosmosdb, file]
base_dir: "{defs.OUTPUT_BASE_DIR}"
## only turn this on if running `graphrag index` with custom settings
## we normally use `graphrag update` with the defaults
update_index_storage:
# type: {defs.STORAGE_TYPE.value} # or blob
# base_dir: "{defs.UPDATE_STORAGE_BASE_DIR}"
update_index_output:
# type: {defs.OUTPUT_TYPE.value} # or blob
# base_dir: "{defs.UPDATE_OUTPUT_BASE_DIR}"
### Workflow settings ###
skip_workflows: []
entity_extraction:
model_id: {defs.ENTITY_EXTRACTION_MODEL_ID}
prompt: "prompts/entity_extraction.txt"
entity_types: [{",".join(defs.ENTITY_EXTRACTION_ENTITY_TYPES)}]
max_gleanings: {defs.ENTITY_EXTRACTION_MAX_GLEANINGS}
summarize_descriptions:
model_id: {defs.SUMMARIZE_MODEL_ID}
prompt: "prompts/summarize_descriptions.txt"
max_length: {defs.SUMMARIZE_DESCRIPTIONS_MAX_LENGTH}
claim_extraction:
enabled: false
model_id: {defs.CLAIM_EXTRACTION_MODEL_ID}
prompt: "prompts/claim_extraction.txt"
description: "{defs.CLAIM_DESCRIPTION}"
max_gleanings: {defs.CLAIM_MAX_GLEANINGS}
community_reports:
model_id: {defs.COMMUNITY_REPORT_MODEL_ID}
prompt: "prompts/community_report.txt"
max_length: {defs.COMMUNITY_REPORT_MAX_LENGTH}
max_input_length: {defs.COMMUNITY_REPORT_MAX_INPUT_LENGTH}

View File

@ -3,23 +3,86 @@
"""Default method for loading config."""
import json
import os
from pathlib import Path
from string import Template
from typing import Any
import yaml
from dotenv import load_dotenv
from graphrag.config.config_file_loader import (
load_config_from_file,
search_for_config_in_root_dir,
)
from graphrag.config.create_graphrag_config import create_graphrag_config
from graphrag.config.models.graph_rag_config import GraphRagConfig
_default_config_files = ["settings.yaml", "settings.yml", "settings.json"]
def load_config(
root_dir: Path,
config_filepath: Path | None = None,
) -> GraphRagConfig:
"""Load configuration from a file or create a default configuration.
If a config file is not found the default configuration is created.
def _search_for_config_in_root_dir(root: str | Path) -> Path | None:
"""Resolve the config path from the given root directory.
Parameters
----------
root : str | Path
The path to the root directory containing the config file.
Searches for a default config file (settings.{yaml,yml,json}).
Returns
-------
Path | None
returns a Path if there is a config in the root directory
Otherwise returns None.
"""
root = Path(root)
if not root.is_dir():
msg = f"Invalid config path: {root} is not a directory"
raise FileNotFoundError(msg)
for file in _default_config_files:
if (root / file).is_file():
return root / file
return None
def _parse_env_variables(text: str) -> str:
"""Parse environment variables in the configuration text.
Parameters
----------
text : str
The configuration text.
Returns
-------
str
The configuration text with environment variables parsed.
Raises
------
KeyError
If an environment variable is not found.
"""
return Template(text).substitute(os.environ)
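# Behaviour sketch (assuming GRAPHRAG_API_KEY is set in the environment):
#   _parse_env_variables("api_key: ${GRAPHRAG_API_KEY}")   # -> "api_key: <value from env>"
#   _parse_env_variables("file_pattern: .*\\.txt$$")        # "$$" escapes to a literal "$"
#   _parse_env_variables("key: ${UNDEFINED_VAR}")           # raises KeyError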
def _load_dotenv(config_path: Path | str) -> None:
"""Load the .env file if it exists in the same directory as the config file.
Parameters
----------
config_path : Path | str
The path to the config file.
"""
config_path = Path(config_path)
dotenv_path = config_path.parent / ".env"
if dotenv_path.exists():
load_dotenv(dotenv_path)
def _get_config_path(root_dir: Path, config_filepath: Path | None) -> Path:
"""Get the configuration file path.
Parameters
----------
@ -27,23 +90,102 @@ def load_config(
The root directory of the project. Will search for the config file in this directory.
config_filepath : str | None
The path to the config file.
If None, searches for config file in root and
if not found creates a default configuration.
"""
root = root_dir.resolve()
If None, searches for config file in root.
# If user specified a config file path then it is required
Returns
-------
Path
The configuration file path.
"""
if config_filepath:
config_path = config_filepath.resolve()
if not config_path.exists():
msg = f"Specified Config file not found: {config_path}"
raise FileNotFoundError(msg)
else:
# resolve config filepath from the root directory if it exists
config_path = search_for_config_in_root_dir(root)
if config_path:
config = load_config_from_file(config_path)
else:
config = create_graphrag_config(root_dir=str(root))
config_path = _search_for_config_in_root_dir(root_dir)
return config
if not config_path:
msg = f"Config file not found in root directory: {root_dir}"
raise FileNotFoundError(msg)
return config_path
def _apply_overrides(data: dict[str, Any], overrides: dict[str, Any]) -> None:
"""Apply the overrides to the raw configuration."""
for key, value in overrides.items():
keys = key.split(".")
target = data
current_path = keys[0]
for k in keys[:-1]:
current_path += f".{k}"
target_obj = target.get(k, {})
if not isinstance(target_obj, dict):
msg = f"Cannot override non-dict value: data[{current_path}] is not a dict."
raise TypeError(msg)
target[k] = target_obj
target = target[k]
target[keys[-1]] = value
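# Behaviour sketch: dotted CLI keys create or extend nested dicts in place.
#   data = {"output": {"type": "file"}}
#   _apply_overrides(data, {"output.base_dir": "custom", "vector_store.db_uri": "./lancedb"})
#   # data == {"output": {"type": "file", "base_dir": "custom"},
#   #          "vector_store": {"db_uri": "./lancedb"}}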
def _parse(file_extension: str, contents: str) -> dict[str, Any]:
"""Parse configuration."""
match file_extension:
case ".yaml" | ".yml":
return yaml.safe_load(contents)
case ".json":
return json.loads(contents)
case _:
msg = (
f"Unable to parse config. Unsupported file extension: {file_extension}"
)
raise ValueError(msg)
def load_config(
root_dir: Path,
config_filepath: Path | None = None,
cli_overrides: dict[str, Any] | None = None,
) -> GraphRagConfig:
"""Load configuration from a file.
Parameters
----------
root_dir : str | Path
The root directory of the project. Will search for the config file in this directory.
config_filepath : str | None
The path to the config file.
If None, searches for config file in root.
cli_overrides : dict[str, Any] | None
A flat dictionary of cli overrides.
Example: {'output.base_dir': 'override_value'}
Returns
-------
GraphRagConfig
The loaded configuration.
Raises
------
FileNotFoundError
If the config file is not found.
ValueError
If the config file extension is not supported.
TypeError
If applying cli overrides to the config fails.
KeyError
If config file references a non-existent environment variable.
ValidationError
If there are pydantic validation errors when instantiating the config.
"""
root = root_dir.resolve()
config_path = _get_config_path(root, config_filepath)
_load_dotenv(config_path)
config_extension = config_path.suffix
config_text = config_path.read_text(encoding="utf-8")
config_text = _parse_env_variables(config_text)
config_data = _parse(config_extension, config_text)
if cli_overrides:
_apply_overrides(config_data, cli_overrides)
return create_graphrag_config(config_data, root_dir=str(root))
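# Minimal usage sketch (hypothetical project path):
#   config = load_config(
#       Path("./ragtest"),                                  # searches for settings.yaml|yml|json
#       config_filepath=None,
#       cli_overrides={"output.base_dir": "custom_out"},    # applied before pydantic validation
#   )
#   chat_model = config.get_language_model_config("default_chat_model")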

View File

@ -22,15 +22,15 @@ class BasicSearchConfig(BaseModel):
description="The conversation history maximum turns.",
default=defs.BASIC_SEARCH_CONVERSATION_HISTORY_MAX_TURNS,
)
temperature: float | None = Field(
temperature: float = Field(
description="The temperature to use for token generation.",
default=defs.BASIC_SEARCH_LLM_TEMPERATURE,
)
top_p: float | None = Field(
top_p: float = Field(
description="The top-p value to use for token generation.",
default=defs.BASIC_SEARCH_LLM_TOP_P,
)
n: int | None = Field(
n: int = Field(
description="The number of completions to generate.",
default=defs.BASIC_SEARCH_LLM_N,
)

View File

@ -3,22 +3,10 @@
"""Parameterization settings for the default configuration."""
from enum import Enum
from pydantic import BaseModel, Field
import graphrag.config.defaults as defs
class ChunkStrategyType(str, Enum):
"""ChunkStrategy class definition."""
tokens = "tokens"
sentence = "sentence"
def __repr__(self):
"""Get a string representation."""
return f'"{self.value}"'
from graphrag.config.enums import ChunkStrategyType
class ChunkingConfig(BaseModel):
@ -33,7 +21,7 @@ class ChunkingConfig(BaseModel):
default=defs.CHUNK_GROUP_BY_COLUMNS,
)
strategy: ChunkStrategyType = Field(
description="The chunking strategy to use.", default=ChunkStrategyType.tokens
description="The chunking strategy to use.", default=defs.CHUNK_STRATEGY
)
encoding_model: str = Field(
description="The encoding model to use.", default=defs.ENCODING_MODEL

View File

@ -5,17 +5,18 @@
from pathlib import Path
from pydantic import Field
from pydantic import BaseModel, Field
import graphrag.config.defaults as defs
from graphrag.config.models.llm_config import LLMConfig
from graphrag.config.models.language_model_config import LanguageModelConfig
class ClaimExtractionConfig(LLMConfig):
class ClaimExtractionConfig(BaseModel):
"""Configuration section for claim extraction."""
enabled: bool = Field(
description="Whether claim extraction is enabled.",
default=defs.CLAIM_EXTRACTION_ENABLED,
)
prompt: str | None = Field(
description="The claim extraction prompt to use.", default=None
@ -34,18 +35,25 @@ class ClaimExtractionConfig(LLMConfig):
encoding_model: str | None = Field(
default=None, description="The encoding model to use."
)
model_id: str = Field(
description="The model ID to use for claim extraction.",
default=defs.CLAIM_EXTRACTION_MODEL_ID,
)
def resolved_strategy(self, root_dir: str, encoding_model: str | None) -> dict:
def resolved_strategy(
self, root_dir: str, model_config: LanguageModelConfig
) -> dict:
"""Get the resolved claim extraction strategy."""
return self.strategy or {
"llm": self.llm.model_dump(),
**self.parallelization.model_dump(),
"extraction_prompt": (Path(root_dir) / self.prompt)
.read_bytes()
.decode(encoding="utf-8")
"llm": model_config.model_dump(),
"stagger": model_config.parallelization_stagger,
"num_threads": model_config.parallelization_num_threads,
"extraction_prompt": (Path(root_dir) / self.prompt).read_text(
encoding="utf-8"
)
if self.prompt
else None,
"claim_description": self.description,
"max_gleanings": self.max_gleanings,
"encoding_name": encoding_model or self.encoding_model,
"encoding_name": model_config.encoding_model,
}
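# Hedged sketch of the resolved strategy with default settings (values come from the
# referenced LanguageModelConfig and this section's own fields):
#   {
#       "llm": {...},                     # model_config.model_dump()
#       "stagger": model_config.parallelization_stagger,
#       "num_threads": model_config.parallelization_num_threads,
#       "extraction_prompt": "<contents of prompts/claim_extraction.txt>",
#       "claim_description": "Any claims or facts that could be relevant to information discovery.",
#       "max_gleanings": 1,
#       "encoding_name": "cl100k_base",
#   }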

View File

@ -18,7 +18,7 @@ class ClusterGraphConfig(BaseModel):
description="Whether to use the largest connected component.",
default=defs.USE_LCC,
)
seed: int | None = Field(
seed: int = Field(
description="The seed to use for the clustering.",
default=defs.CLUSTER_GRAPH_SEED,
)

View File

@ -5,13 +5,13 @@
from pathlib import Path
from pydantic import Field
from pydantic import BaseModel, Field
import graphrag.config.defaults as defs
from graphrag.config.models.llm_config import LLMConfig
from graphrag.config.models.language_model_config import LanguageModelConfig
class CommunityReportsConfig(LLMConfig):
class CommunityReportsConfig(BaseModel):
"""Configuration section for community reports."""
prompt: str | None = Field(
@ -28,8 +28,14 @@ class CommunityReportsConfig(LLMConfig):
strategy: dict | None = Field(
description="The override strategy to use.", default=None
)
model_id: str = Field(
description="The model ID to use for community reports.",
default=defs.COMMUNITY_REPORT_MODEL_ID,
)
def resolved_strategy(self, root_dir) -> dict:
def resolved_strategy(
self, root_dir: str, model_config: LanguageModelConfig
) -> dict:
"""Get the resolved community report extraction strategy."""
from graphrag.index.operations.summarize_communities import (
CreateCommunityReportsStrategyType,
@ -37,11 +43,12 @@ class CommunityReportsConfig(LLMConfig):
return self.strategy or {
"type": CreateCommunityReportsStrategyType.graph_intelligence,
"llm": self.llm.model_dump(),
**self.parallelization.model_dump(),
"extraction_prompt": (Path(root_dir) / self.prompt)
.read_bytes()
.decode(encoding="utf-8")
"llm": model_config.model_dump(),
"stagger": model_config.parallelization_stagger,
"num_threads": model_config.parallelization_num_threads,
"extraction_prompt": (Path(root_dir) / self.prompt).read_text(
encoding="utf-8"
)
if self.prompt
else None,
"max_report_length": self.max_length,

View File

@ -5,13 +5,13 @@
from pathlib import Path
from pydantic import Field
from pydantic import BaseModel, Field
import graphrag.config.defaults as defs
from graphrag.config.models.llm_config import LLMConfig
from graphrag.config.models.language_model_config import LanguageModelConfig
class EntityExtractionConfig(LLMConfig):
class EntityExtractionConfig(BaseModel):
"""Configuration section for entity extraction."""
prompt: str | None = Field(
@ -31,8 +31,14 @@ class EntityExtractionConfig(LLMConfig):
encoding_model: str | None = Field(
default=None, description="The encoding model to use."
)
model_id: str = Field(
description="The model ID to use for text embeddings.",
default=defs.ENTITY_EXTRACTION_MODEL_ID,
)
def resolved_strategy(self, root_dir: str, encoding_model: str | None) -> dict:
def resolved_strategy(
self, root_dir: str, model_config: LanguageModelConfig
) -> dict:
"""Get the resolved entity extraction strategy."""
from graphrag.index.operations.extract_entities import (
ExtractEntityStrategyType,
@ -40,13 +46,14 @@ class EntityExtractionConfig(LLMConfig):
return self.strategy or {
"type": ExtractEntityStrategyType.graph_intelligence,
"llm": self.llm.model_dump(),
**self.parallelization.model_dump(),
"extraction_prompt": (Path(root_dir) / self.prompt)
.read_bytes()
.decode(encoding="utf-8")
"llm": model_config.model_dump(),
"stagger": model_config.parallelization_stagger,
"num_threads": model_config.parallelization_num_threads,
"extraction_prompt": (Path(root_dir) / self.prompt).read_text(
encoding="utf-8"
)
if self.prompt
else None,
"max_gleanings": self.max_gleanings,
"encoding_name": encoding_model or self.encoding_model,
"encoding_name": model_config.encoding_model,
}

View File

@ -20,15 +20,15 @@ class GlobalSearchConfig(BaseModel):
knowledge_prompt: str | None = Field(
description="The global search general prompt to use.", default=None
)
temperature: float | None = Field(
temperature: float = Field(
description="The temperature to use for token generation.",
default=defs.GLOBAL_SEARCH_LLM_TEMPERATURE,
)
top_p: float | None = Field(
top_p: float = Field(
description="The top-p value to use for token generation.",
default=defs.GLOBAL_SEARCH_LLM_TOP_P,
)
n: int | None = Field(
n: int = Field(
description="The number of completions to generate.",
default=defs.GLOBAL_SEARCH_LLM_N,
)

View File

@ -3,10 +3,13 @@
"""Parameterization settings for the default configuration."""
from pathlib import Path
from devtools import pformat
from pydantic import Field
from pydantic import BaseModel, Field, model_validator
import graphrag.config.defaults as defs
from graphrag.config.errors import LanguageModelConfigMissingError
from graphrag.config.models.basic_search_config import BasicSearchConfig
from graphrag.config.models.cache_config import CacheConfig
from graphrag.config.models.chunking_config import ChunkingConfig
@ -18,19 +21,21 @@ from graphrag.config.models.embed_graph_config import EmbedGraphConfig
from graphrag.config.models.entity_extraction_config import EntityExtractionConfig
from graphrag.config.models.global_search_config import GlobalSearchConfig
from graphrag.config.models.input_config import InputConfig
from graphrag.config.models.llm_config import LLMConfig
from graphrag.config.models.language_model_config import LanguageModelConfig
from graphrag.config.models.local_search_config import LocalSearchConfig
from graphrag.config.models.output_config import OutputConfig
from graphrag.config.models.reporting_config import ReportingConfig
from graphrag.config.models.snapshots_config import SnapshotsConfig
from graphrag.config.models.storage_config import StorageConfig
from graphrag.config.models.summarize_descriptions_config import (
SummarizeDescriptionsConfig,
)
from graphrag.config.models.text_embedding_config import TextEmbeddingConfig
from graphrag.config.models.umap_config import UmapConfig
from graphrag.config.models.vector_store_config import VectorStoreConfig
from graphrag.vector_stores.factory import VectorStoreType
class GraphRagConfig(LLMConfig):
class GraphRagConfig(BaseModel):
"""Base class for the Default-Configuration parameterization settings."""
def __repr__(self) -> str:
@ -42,24 +47,91 @@ class GraphRagConfig(LLMConfig):
return self.model_dump_json(indent=4)
root_dir: str = Field(
description="The root directory for the configuration.", default="."
description="The root directory for the configuration.", default=""
)
def _validate_root_dir(self) -> None:
"""Validate the root directory."""
if self.root_dir.strip() == "":
self.root_dir = str(Path.cwd())
root_dir = Path(self.root_dir).resolve()
if not root_dir.is_dir():
msg = f"Invalid root directory: {self.root_dir} is not a directory."
raise FileNotFoundError(msg)
self.root_dir = str(root_dir)
models: dict[str, LanguageModelConfig] = Field(
description="Available language model configurations.",
default={},
)
def _validate_models(self) -> None:
"""Validate the models configuration.
Ensure both a default chat model and default embedding model
have been defined. Other models may also be defined but
defaults are required for the time being because parts of the
code fall back to the default model configs instead
of specifying a specific model.
TODO: Don't fall back to default models elsewhere in the code.
Force callers to specify a model to use and allow arbitrary
names for model configurations.
"""
if defs.DEFAULT_CHAT_MODEL_ID not in self.models:
raise LanguageModelConfigMissingError(defs.DEFAULT_CHAT_MODEL_ID)
if defs.DEFAULT_EMBEDDING_MODEL_ID not in self.models:
raise LanguageModelConfigMissingError(defs.DEFAULT_EMBEDDING_MODEL_ID)
reporting: ReportingConfig = Field(
description="The reporting configuration.", default=ReportingConfig()
)
"""The reporting configuration."""
storage: StorageConfig = Field(
description="The storage configuration.", default=StorageConfig()
)
"""The storage configuration."""
def _validate_reporting_base_dir(self) -> None:
"""Validate the reporting base directory."""
if self.reporting.type == defs.ReportingType.file:
if self.reporting.base_dir.strip() == "":
msg = "Reporting base directory is required for file reporting. Please rerun `graphrag init` and set the reporting configuration."
raise ValueError(msg)
self.reporting.base_dir = str(
(Path(self.root_dir) / self.reporting.base_dir).resolve()
)
update_index_storage: StorageConfig | None = Field(
description="The storage configuration for the updated index.",
output: OutputConfig = Field(
description="The output configuration.", default=OutputConfig()
)
"""The output configuration."""
def _validate_output_base_dir(self) -> None:
"""Validate the output base directory."""
if self.output.type == defs.OutputType.file:
if self.output.base_dir.strip() == "":
msg = "output base directory is required for file output. Please rerun `graphrag init` and set the output configuration."
raise ValueError(msg)
self.output.base_dir = str(
(Path(self.root_dir) / self.output.base_dir).resolve()
)
update_index_output: OutputConfig | None = Field(
description="The output configuration for the updated index.",
default=None,
)
"""The storage configuration for the updated index."""
"""The output configuration for the updated index."""
def _validate_update_index_output_base_dir(self) -> None:
"""Validate the update index output base directory."""
if (
self.update_index_output
and self.update_index_output.type == defs.OutputType.file
):
if self.update_index_output.base_dir.strip() == "":
msg = "Update index output base directory is required for file output. Please rerun `graphrag init` and set the update index output configuration."
raise ValueError(msg)
self.update_index_output.base_dir = str(
(Path(self.root_dir) / self.update_index_output.base_dir).resolve()
)
cache: CacheConfig = Field(
description="The cache configuration.", default=CacheConfig()
@ -152,12 +224,52 @@ class GraphRagConfig(LLMConfig):
)
"""The basic search configuration."""
encoding_model: str = Field(
description="The encoding model to use.", default=defs.ENCODING_MODEL
vector_store: VectorStoreConfig = Field(
description="The vector store configuration.", default=VectorStoreConfig()
)
"""The encoding model to use."""
"""The vector store configuration."""
skip_workflows: list[str] = Field(
description="The workflows to skip, usually for testing reasons.", default=[]
)
"""The workflows to skip, usually for testing reasons."""
def _validate_vector_store_db_uri(self) -> None:
"""Validate the vector store configuration."""
if self.vector_store.type == VectorStoreType.LanceDB:
if not self.vector_store.db_uri or self.vector_store.db_uri.strip() == "":
msg = "Vector store URI is required for LanceDB. Please rerun `graphrag init` and set the vector store configuration."
raise ValueError(msg)
self.vector_store.db_uri = str(
(Path(self.root_dir) / self.vector_store.db_uri).resolve()
)
def get_language_model_config(self, model_id: str) -> LanguageModelConfig:
"""Get a model configuration by ID.
Parameters
----------
model_id : str
The ID of the model to get. Should match an ID in the models list.
Returns
-------
LanguageModelConfig
The model configuration if found.
Raises
------
ValueError
If the model ID is not found in the configuration.
"""
if model_id not in self.models:
err_msg = f"Model ID {model_id} not found in configuration. Please rerun `graphrag init` and set the model configuration."
raise ValueError(err_msg)
return self.models[model_id]
@model_validator(mode="after")
def _validate_model(self):
"""Validate the model configuration."""
self._validate_root_dir()
self._validate_models()
self._validate_reporting_base_dir()
self._validate_output_base_dir()
self._validate_update_index_output_base_dir()
self._validate_vector_store_db_uri()
return self
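# Usage sketch (hedged): validation runs at construction time, so a missing default model
# fails fast, and model lookups afterwards go through the helper above.
#   config = GraphRagConfig(root_dir=".", models={...})   # raises LanguageModelConfigMissingError
#                                                         # unless both default model IDs are present
#   chat_model = config.get_language_model_config("default_chat_model")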

View File

@ -30,7 +30,7 @@ class InputConfig(BaseModel):
container_name: str | None = Field(
description="The azure blob storage container name to use.", default=None
)
encoding: str | None = Field(
encoding: str = Field(
description="The input file encoding to use.",
default=defs.INPUT_FILE_ENCODING,
)

View File

@ -0,0 +1,239 @@
# Copyright (c) 2024 Microsoft Corporation.
# Licensed under the MIT License
"""Language model configuration."""
import tiktoken
from pydantic import BaseModel, Field, model_validator
import graphrag.config.defaults as defs
from graphrag.config.enums import AsyncType, AzureAuthType, LLMType
from graphrag.config.errors import (
ApiKeyMissingError,
AzureApiBaseMissingError,
AzureApiVersionMissingError,
AzureDeploymentNameMissingError,
ConflictingSettingsError,
)
class LanguageModelConfig(BaseModel):
"""Language model configuration."""
api_key: str | None = Field(
description="The API key to use for the LLM service.",
default=None,
)
def _validate_api_key(self) -> None:
"""Validate the API key.
API Key is required when using OpenAI API
or when using Azure API with API Key authentication.
For the time being, this check is extra verbose for clarity.
It will also throw an exception if an API Key is provided
when one is not expected, such as when using Azure
Managed Identity.
Raises
------
ApiKeyMissingError
If the API key is missing and is required.
"""
if (
self.type == LLMType.OpenAIEmbedding
or self.type == LLMType.OpenAIChat
or self.azure_auth_type == AzureAuthType.APIKey
) and (self.api_key is None or self.api_key.strip() == ""):
raise ApiKeyMissingError(
self.type.value,
self.azure_auth_type.value if self.azure_auth_type else None,
)
if (self.azure_auth_type == AzureAuthType.ManagedIdentity) and (
self.api_key is not None and self.api_key.strip() != ""
):
msg = "API Key should not be provided when using Azure Managed Identity. Please rerun `graphrag init` and remove the api_key when using Azure Managed Identity."
raise ConflictingSettingsError(msg)
azure_auth_type: AzureAuthType | None = Field(
description="The Azure authentication type to use when using AOI.",
default=None,
)
type: LLMType = Field(description="The type of LLM model to use.")
model: str = Field(description="The LLM model to use.")
encoding_model: str = Field(description="The encoding model to use", default="")
def _validate_encoding_model(self) -> None:
"""Validate the encoding model.
Raises
------
KeyError
If the model name is not recognized.
"""
if self.encoding_model.strip() == "":
self.encoding_model = tiktoken.encoding_name_for_model(self.model)
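# Illustrative: when encoding_model is left blank it is derived from the model name via
# tiktoken, e.g. tiktoken.encoding_name_for_model("gpt-4") -> "cl100k_base"; an
# unrecognized model name raises KeyError.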
max_tokens: int = Field(
description="The maximum number of tokens to generate.",
default=defs.LLM_MAX_TOKENS,
)
temperature: float = Field(
description="The temperature to use for token generation.",
default=defs.LLM_TEMPERATURE,
)
top_p: float = Field(
description="The top-p value to use for token generation.",
default=defs.LLM_TOP_P,
)
n: int = Field(
description="The number of completions to generate.",
default=defs.LLM_N,
)
frequency_penalty: float = Field(
description="The frequency penalty to use for token generation.",
default=defs.LLM_FREQUENCY_PENALTY,
)
presence_penalty: float = Field(
description="The presence penalty to use for token generation.",
default=defs.LLM_PRESENCE_PENALTY,
)
request_timeout: float = Field(
description="The request timeout to use.", default=defs.LLM_REQUEST_TIMEOUT
)
api_base: str | None = Field(
description="The base URL for the LLM API.", default=None
)
def _validate_api_base(self) -> None:
"""Validate the API base.
Required when using Azure OpenAI.
Raises
------
AzureApiBaseMissingError
If the API base is missing and is required.
"""
if (
self.type == LLMType.AzureOpenAIChat
or self.type == LLMType.AzureOpenAIEmbedding
) and (self.api_base is None or self.api_base.strip() == ""):
raise AzureApiBaseMissingError(self.type.value)
api_version: str | None = Field(
description="The version of the LLM API to use.", default=None
)
def _validate_api_version(self) -> None:
"""Validate the API version.
Required when using Azure OpenAI.
Raises
------
AzureApiVersionMissingError
If the API version is missing and is required.
"""
if (
self.type == LLMType.AzureOpenAIChat
or self.type == LLMType.AzureOpenAIEmbedding
) and (self.api_version is None or self.api_version.strip() == ""):
raise AzureApiVersionMissingError(self.type.value)
deployment_name: str | None = Field(
description="The deployment name to use for the LLM service.", default=None
)
def _validate_deployment_name(self) -> None:
"""Validate the deployment name.
Required when using Azure OpenAI.
Raises
------
AzureDeploymentNameMissingError
If the deployment name is missing and is required.
"""
if (
self.type == LLMType.AzureOpenAIChat
or self.type == LLMType.AzureOpenAIEmbedding
) and (self.deployment_name is None or self.deployment_name.strip() == ""):
raise AzureDeploymentNameMissingError(self.type.value)
organization: str | None = Field(
description="The organization to use for the LLM service.", default=None
)
proxy: str | None = Field(
description="The proxy to use for the LLM service.", default=None
)
audience: str | None = Field(
description="Azure resource URI to use with managed identity for the llm connection.",
default=None,
)
model_supports_json: bool | None = Field(
description="Whether the model supports JSON output mode.", default=None
)
tokens_per_minute: int = Field(
description="The number of tokens per minute to use for the LLM service.",
default=defs.LLM_TOKENS_PER_MINUTE,
)
requests_per_minute: int = Field(
description="The number of requests per minute to use for the LLM service.",
default=defs.LLM_REQUESTS_PER_MINUTE,
)
max_retries: int = Field(
description="The maximum number of retries to use for the LLM service.",
default=defs.LLM_MAX_RETRIES,
)
max_retry_wait: float = Field(
description="The maximum retry wait to use for the LLM service.",
default=defs.LLM_MAX_RETRY_WAIT,
)
sleep_on_rate_limit_recommendation: bool = Field(
description="Whether to sleep on rate limit recommendations.",
default=defs.LLM_SLEEP_ON_RATE_LIMIT_RECOMMENDATION,
)
concurrent_requests: int = Field(
description="Whether to use concurrent requests for the LLM service.",
default=defs.LLM_CONCURRENT_REQUESTS,
)
responses: list[str | BaseModel] | None = Field(
default=None, description="Static responses to use in mock mode."
)
parallelization_stagger: float = Field(
description="The stagger to use for the LLM service.",
default=defs.PARALLELIZATION_STAGGER,
)
parallelization_num_threads: int = Field(
description="The number of threads to use for the LLM service.",
default=defs.PARALLELIZATION_NUM_THREADS,
)
async_mode: AsyncType = Field(
description="The async mode to use.", default=defs.ASYNC_MODE
)
def _validate_azure_settings(self) -> None:
"""Validate the Azure settings.
Raises
------
AzureApiBaseMissingError
If the API base is missing and is required.
AzureApiVersionMissingError
If the API version is missing and is required.
AzureDeploymentNameMissingError
If the deployment name is missing and is required.
"""
self._validate_api_base()
self._validate_api_version()
self._validate_deployment_name()
@model_validator(mode="after")
def _validate_model(self):
self._validate_api_key()
self._validate_azure_settings()
self._validate_encoding_model()
return self
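# Validation sketch (hypothetical values): the model_validator runs every check on construction.
#   LanguageModelConfig(type=LLMType.OpenAIChat, model="gpt-4o")
#       # raises ApiKeyMissingError (no api_key for an OpenAI model)
#   LanguageModelConfig(type=LLMType.AzureOpenAIChat, model="gpt-4o", api_key="...")
#       # raises AzureApiBaseMissingError (api_base is required for Azure)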

View File

@ -1,26 +0,0 @@
# Copyright (c) 2024 Microsoft Corporation.
# Licensed under the MIT License
"""Parameterization settings for the default configuration."""
from pydantic import BaseModel, Field
import graphrag.config.defaults as defs
from graphrag.config.enums import AsyncType
from graphrag.config.models.llm_parameters import LLMParameters
from graphrag.config.models.parallelization_parameters import ParallelizationParameters
class LLMConfig(BaseModel):
"""Base class for LLM-configured steps."""
llm: LLMParameters = Field(
description="The LLM configuration to use.", default=LLMParameters()
)
parallelization: ParallelizationParameters = Field(
description="The parallelization configuration to use.",
default=ParallelizationParameters(),
)
async_mode: AsyncType = Field(
description="The async mode to use.", default=defs.ASYNC_MODE
)

View File

@ -1,102 +0,0 @@
# Copyright (c) 2024 Microsoft Corporation.
# Licensed under the MIT License
"""LLM Parameters model."""
from pydantic import BaseModel, ConfigDict, Field
import graphrag.config.defaults as defs
from graphrag.config.enums import LLMType
class LLMParameters(BaseModel):
"""LLM Parameters model."""
model_config = ConfigDict(protected_namespaces=(), extra="allow")
api_key: str | None = Field(
description="The API key to use for the LLM service.",
default=None,
)
type: LLMType = Field(
description="The type of LLM model to use.", default=defs.LLM_TYPE
)
encoding_model: str | None = Field(
description="The encoding model to use", default=defs.ENCODING_MODEL
)
model: str = Field(description="The LLM model to use.", default=defs.LLM_MODEL)
max_tokens: int | None = Field(
description="The maximum number of tokens to generate.",
default=defs.LLM_MAX_TOKENS,
)
temperature: float | None = Field(
description="The temperature to use for token generation.",
default=defs.LLM_TEMPERATURE,
)
top_p: float | None = Field(
description="The top-p value to use for token generation.",
default=defs.LLM_TOP_P,
)
n: int | None = Field(
description="The number of completions to generate.",
default=defs.LLM_N,
)
frequency_penalty: float | None = Field(
description="The frequency penalty to use for token generation.",
default=defs.LLM_FREQUENCY_PENALTY,
)
presence_penalty: float | None = Field(
description="The presence penalty to use for token generation.",
default=defs.LLM_PRESENCE_PENALTY,
)
request_timeout: float = Field(
description="The request timeout to use.", default=defs.LLM_REQUEST_TIMEOUT
)
api_base: str | None = Field(
description="The base URL for the LLM API.", default=None
)
api_version: str | None = Field(
description="The version of the LLM API to use.", default=None
)
organization: str | None = Field(
description="The organization to use for the LLM service.", default=None
)
proxy: str | None = Field(
description="The proxy to use for the LLM service.", default=None
)
audience: str | None = Field(
description="Azure resource URI to use with managed identity for the llm connection.",
default=None,
)
deployment_name: str | None = Field(
description="The deployment name to use for the LLM service.", default=None
)
model_supports_json: bool | None = Field(
description="Whether the model supports JSON output mode.", default=None
)
tokens_per_minute: int = Field(
description="The number of tokens per minute to use for the LLM service.",
default=defs.LLM_TOKENS_PER_MINUTE,
)
requests_per_minute: int = Field(
description="The number of requests per minute to use for the LLM service.",
default=defs.LLM_REQUESTS_PER_MINUTE,
)
max_retries: int = Field(
description="The maximum number of retries to use for the LLM service.",
default=defs.LLM_MAX_RETRIES,
)
max_retry_wait: float = Field(
description="The maximum retry wait to use for the LLM service.",
default=defs.LLM_MAX_RETRY_WAIT,
)
sleep_on_rate_limit_recommendation: bool = Field(
description="Whether to sleep on rate limit recommendations.",
default=defs.LLM_SLEEP_ON_RATE_LIMIT_RECOMMENDATION,
)
concurrent_requests: int = Field(
description="Whether to use concurrent requests for the LLM service.",
default=defs.LLM_CONCURRENT_REQUESTS,
)
responses: list[str | BaseModel] | None = Field(
default=None, description="Static responses to use in mock mode."
)

View File

@ -34,15 +34,15 @@ class LocalSearchConfig(BaseModel):
description="The top k mapped relations.",
default=defs.LOCAL_SEARCH_TOP_K_RELATIONSHIPS,
)
temperature: float | None = Field(
temperature: float = Field(
description="The temperature to use for token generation.",
default=defs.LOCAL_SEARCH_LLM_TEMPERATURE,
)
top_p: float | None = Field(
top_p: float = Field(
description="The top-p value to use for token generation.",
default=defs.LOCAL_SEARCH_LLM_TOP_P,
)
n: int | None = Field(
n: int = Field(
description="The number of completions to generate.",
default=defs.LOCAL_SEARCH_LLM_N,
)

View File

@ -6,18 +6,18 @@
from pydantic import BaseModel, Field
import graphrag.config.defaults as defs
from graphrag.config.enums import StorageType
from graphrag.config.enums import OutputType
class StorageConfig(BaseModel):
"""The default configuration section for Storage."""
class OutputConfig(BaseModel):
"""The default configuration section for Output."""
type: StorageType = Field(
description="The storage type to use.", default=defs.STORAGE_TYPE
type: OutputType = Field(
description="The output type to use.", default=defs.OUTPUT_TYPE
)
base_dir: str = Field(
description="The base directory for the storage.",
default=defs.STORAGE_BASE_DIR,
description="The base directory for the output.",
default=defs.OUTPUT_BASE_DIR,
)
connection_string: str | None = Field(
description="The storage connection string to use.", default=None

View File

@ -1,21 +0,0 @@
# Copyright (c) 2024 Microsoft Corporation.
# Licensed under the MIT License
"""LLM Parameters model."""
from pydantic import BaseModel, Field
import graphrag.config.defaults as defs
class ParallelizationParameters(BaseModel):
"""LLM Parameters model."""
stagger: float = Field(
description="The stagger to use for the LLM service.",
default=defs.PARALLELIZATION_STAGGER,
)
num_threads: int = Field(
description="The number of threads to use for the LLM service.",
default=defs.PARALLELIZATION_NUM_THREADS,
)

View File

@ -5,13 +5,13 @@
from pathlib import Path
from pydantic import Field
from pydantic import BaseModel, Field
import graphrag.config.defaults as defs
from graphrag.config.models.llm_config import LLMConfig
from graphrag.config.models.language_model_config import LanguageModelConfig
class SummarizeDescriptionsConfig(LLMConfig):
class SummarizeDescriptionsConfig(BaseModel):
"""Configuration section for description summarization."""
prompt: str | None = Field(
@ -24,8 +24,14 @@ class SummarizeDescriptionsConfig(LLMConfig):
strategy: dict | None = Field(
description="The override strategy to use.", default=None
)
model_id: str = Field(
description="The model ID to use for summarization.",
default=defs.SUMMARIZE_MODEL_ID,
)
def resolved_strategy(self, root_dir: str) -> dict:
def resolved_strategy(
self, root_dir: str, model_config: LanguageModelConfig
) -> dict:
"""Get the resolved description summarization strategy."""
from graphrag.index.operations.summarize_descriptions import (
SummarizeStrategyType,
@ -33,11 +39,12 @@ class SummarizeDescriptionsConfig(LLMConfig):
return self.strategy or {
"type": SummarizeStrategyType.graph_intelligence,
"llm": self.llm.model_dump(),
**self.parallelization.model_dump(),
"summarize_prompt": (Path(root_dir) / self.prompt)
.read_bytes()
.decode(encoding="utf-8")
"llm": model_config.model_dump(),
"stagger": model_config.parallelization_stagger,
"num_threads": model_config.parallelization_num_threads,
"summarize_prompt": (Path(root_dir) / self.prompt).read_text(
encoding="utf-8"
)
if self.prompt
else None,
"max_summary_length": self.max_length,

View File

@ -3,14 +3,14 @@
"""Parameterization settings for the default configuration."""
from pydantic import Field
from pydantic import BaseModel, Field
import graphrag.config.defaults as defs
from graphrag.config.enums import TextEmbeddingTarget
from graphrag.config.models.llm_config import LLMConfig
from graphrag.config.models.language_model_config import LanguageModelConfig
class TextEmbeddingConfig(LLMConfig):
class TextEmbeddingConfig(BaseModel):
"""Configuration section for text embeddings."""
batch_size: int = Field(
@ -21,18 +21,21 @@ class TextEmbeddingConfig(LLMConfig):
default=defs.EMBEDDING_BATCH_MAX_TOKENS,
)
target: TextEmbeddingTarget = Field(
description="The target to use. 'all' or 'required'.",
description="The target to use. 'all', 'required', 'selected', or 'none'.",
default=defs.EMBEDDING_TARGET,
)
skip: list[str] = Field(description="The specific embeddings to skip.", default=[])
vector_store: dict | None = Field(
description="The vector storage configuration", default=defs.VECTOR_STORE_DICT
names: list[str] = Field(
description="The specific embeddings to perform.", default=[]
)
strategy: dict | None = Field(
description="The override strategy to use.", default=None
)
model_id: str = Field(
description="The model ID to use for text embeddings.",
default=defs.EMBEDDING_MODEL_ID,
)
def resolved_strategy(self) -> dict:
def resolved_strategy(self, model_config: LanguageModelConfig) -> dict:
"""Get the resolved text embedding strategy."""
from graphrag.index.operations.embed_text import (
TextEmbedStrategyType,
@ -40,8 +43,9 @@ class TextEmbeddingConfig(LLMConfig):
return self.strategy or {
"type": TextEmbedStrategyType.openai,
"llm": self.llm.model_dump(),
**self.parallelization.model_dump(),
"llm": model_config.model_dump(),
"stagger": model_config.parallelization_stagger,
"num_threads": model_config.parallelization_num_threads,
"batch_size": self.batch_size,
"batch_max_tokens": self.batch_max_tokens,
}
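The embedding strategy resolves the same way: look up the model referenced by embeddings.model_id, then pass it in. A minimal sketch under the same assumption of a loaded GraphRagConfig:

# Sketch only, not part of this diff; assumes `config` is a loaded GraphRagConfig.
def embedding_strategy(config) -> dict:
    embedding_model = config.get_language_model_config(config.embeddings.model_id)
    return config.embeddings.resolved_strategy(embedding_model)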

View File

@ -0,0 +1,78 @@
# Copyright (c) 2024 Microsoft Corporation.
# Licensed under the MIT License
"""Parameterization settings for the default configuration."""
from pydantic import BaseModel, Field, model_validator
import graphrag.config.defaults as defs
from graphrag.vector_stores.factory import VectorStoreType
class VectorStoreConfig(BaseModel):
"""The default configuration section for Vector Store."""
type: str = Field(
description="The vector store type to use.",
default=defs.VECTOR_STORE_TYPE,
)
db_uri: str | None = Field(description="The database URI to use.", default=None)
def _validate_db_uri(self) -> None:
"""Validate the database URI."""
if self.type == VectorStoreType.LanceDB.value and (
self.db_uri is None or self.db_uri.strip() == ""
):
self.db_uri = defs.VECTOR_STORE_DB_URI
if self.type != VectorStoreType.LanceDB.value and (
self.db_uri is not None and self.db_uri.strip() != ""
):
msg = "vector_store.db_uri is only used when vector_store.type == lancedb. Please rerun `graphrag init` and select the correct vector store type."
raise ValueError(msg)
url: str | None = Field(
description="The database URL when type == azure_ai_search.",
default=None,
)
def _validate_url(self) -> None:
"""Validate the database URL."""
if self.type == VectorStoreType.AzureAISearch and (
self.url is None or self.url.strip() == ""
):
msg = "vector_store.url is required when vector_store.type == azure_ai_search. Please rerun `graphrag init` and select the correct vector store type."
raise ValueError(msg)
if self.type != VectorStoreType.AzureAISearch and (
self.url is not None and self.url.strip() != ""
):
msg = "vector_store.url is only used when vector_store.type == azure_ai_search. Please rerun `graphrag init` and select the correct vector store type."
raise ValueError(msg)
api_key: str | None = Field(
description="The database API key when type == azure_ai_search.",
default=None,
)
audience: str | None = Field(
description="The database audience when type == azure_ai_search.",
default=None,
)
container_name: str = Field(
description="The database name to use.",
default=defs.VECTOR_STORE_CONTAINER_NAME,
)
overwrite: bool = Field(
description="Overwrite the existing data.", default=defs.VECTOR_STORE_OVERWRITE
)
@model_validator(mode="after")
def _validate_model(self):
"""Validate the model."""
self._validate_db_uri()
self._validate_url()
return self
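A brief sketch of how these validators behave, assuming the VectorStoreConfig defined above is importable and that VectorStoreType values compare against the string type field as the code implies:

from graphrag.vector_stores.factory import VectorStoreType

# LanceDB with no db_uri falls back to the default URI from graphrag.config.defaults.
lance = VectorStoreConfig(type=VectorStoreType.LanceDB.value)
print(lance.db_uri)

# Azure AI Search requires a url; omitting it fails validation with a message
# pointing the user back to `graphrag init` (pydantic surfaces the ValueError
# as a ValidationError).
try:
    VectorStoreConfig(type=VectorStoreType.AzureAISearch.value)
except Exception as err:
    print(err)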

View File

@ -1,214 +0,0 @@
# Copyright (c) 2024 Microsoft Corporation.
# Licensed under the MIT License
"""Resolve timestamp variables in a path."""
import re
from pathlib import Path
from string import Template
from graphrag.config.enums import ReportingType, StorageType
from graphrag.config.models.graph_rag_config import GraphRagConfig
from graphrag.vector_stores.factory import VectorStoreType
def _resolve_timestamp_path_with_value(path: str | Path, timestamp_value: str) -> Path:
"""Resolve the timestamp in the path with the given timestamp value.
Parameters
----------
path : str | Path
The path containing ${timestamp} variables to resolve.
timestamp_value : str
The timestamp value used to resolve the path.
Returns
-------
Path
The path with ${timestamp} variables resolved to the provided timestamp value.
"""
template = Template(str(path))
resolved_path = template.substitute(timestamp=timestamp_value)
return Path(resolved_path)
def _resolve_timestamp_path_with_dir(
path: str | Path, pattern: re.Pattern[str]
) -> Path:
"""Resolve the timestamp in the path with the latest available timestamp directory value.
Parameters
----------
path : str | Path
The path containing ${timestamp} variables to resolve.
pattern : re.Pattern[str]
The pattern to use to match the timestamp directories.
Returns
-------
Path
The path with ${timestamp} variables resolved to the latest available timestamp directory value.
Raises
------
ValueError
If the parent directory expecting to contain timestamp directories does not exist or is not a directory.
Or if no timestamp directories are found in the parent directory that match the pattern.
"""
path = Path(path)
path_parts = path.parts
parent_dir = Path(path_parts[0])
found_timestamp_pattern = False
for _, part in enumerate(path_parts[1:]):
if part.lower() == "${timestamp}":
found_timestamp_pattern = True
break
parent_dir = parent_dir / part
# Path not using timestamp layout.
if not found_timestamp_pattern:
return path
if not parent_dir.exists() or not parent_dir.is_dir():
msg = f"Parent directory {parent_dir} does not exist or is not a directory."
raise ValueError(msg)
timestamp_dirs = [
d for d in parent_dir.iterdir() if d.is_dir() and pattern.match(d.name)
]
timestamp_dirs.sort(key=lambda d: d.name, reverse=True)
if len(timestamp_dirs) == 0:
msg = f"No timestamp directories found in {parent_dir} that match {pattern.pattern}."
raise ValueError(msg)
return _resolve_timestamp_path_with_value(path, timestamp_dirs[0].name)
def _resolve_timestamp_path(
path: str | Path,
pattern_or_timestamp_value: re.Pattern[str] | str | None = None,
) -> Path:
r"""Timestamp path resolver.
Resolve the timestamp in the path with the given timestamp value or
with the latest available timestamp directory matching the given pattern.
Parameters
----------
path : str | Path
The path containing ${timestamp} variables to resolve.
pattern_or_timestamp_value : re.Pattern[str] | str, default=re.compile(r"^\d{8}-\d{6}$")
The pattern to use to match the timestamp directories or the timestamp value to use.
If a string is provided, the path will be resolved with the given string value.
Otherwise, the path will be resolved with the latest available timestamp directory
that matches the given pattern.
Returns
-------
Path
The path with ${timestamp} variables resolved to the provided timestamp value or
the latest available timestamp directory.
Raises
------
ValueError
If the parent directory expecting to contain timestamp directories does not exist or is not a directory.
Or if no timestamp directories are found in the parent directory that match the pattern.
"""
if not pattern_or_timestamp_value:
pattern_or_timestamp_value = re.compile(r"^\d{8}-\d{6}$")
if isinstance(pattern_or_timestamp_value, str):
return _resolve_timestamp_path_with_value(path, pattern_or_timestamp_value)
return _resolve_timestamp_path_with_dir(path, pattern_or_timestamp_value)
def resolve_path(
path_to_resolve: Path | str,
root_dir: Path | str | None = None,
pattern_or_timestamp_value: re.Pattern[str] | str | None = None,
) -> Path:
"""Resolve the path.
Resolves any timestamp variables by either using the provided timestamp value if string or
by looking up the latest available timestamp directory that matches the given pattern.
Resolves the path against the root directory if provided.
Parameters
----------
path_to_resolve : Path | str
The path to resolve.
root_dir : Path | str | None default=None
The root directory to resolve the path from, if provided.
pattern_or_timestamp_value : re.Pattern[str] | str, default=None
The pattern to use to match the timestamp directories or the timestamp value to use.
If a string is provided, the path will be resolved with the given string value.
Otherwise, the path will be resolved with the latest available timestamp directory
that matches the given pattern.
Returns
-------
Path
The resolved path.
"""
if root_dir:
path_to_resolve = (Path(root_dir) / path_to_resolve).resolve()
else:
path_to_resolve = Path(path_to_resolve)
return _resolve_timestamp_path(path_to_resolve, pattern_or_timestamp_value)
def resolve_paths(
config: GraphRagConfig,
pattern_or_timestamp_value: re.Pattern[str] | str | None = None,
) -> None:
"""Resolve storage and reporting paths in the configuration for local file handling.
Resolves any timestamp variables in the configuration paths by either using the provided timestamp value if string or
by looking up the latest available timestamp directory that matches the given pattern.
Parameters
----------
config : GraphRagConfig
The configuration to resolve the paths in.
pattern_or_timestamp_value : re.Pattern[str] | str, default=None
The pattern to use to match the timestamp directories or the timestamp value to use.
If a string is provided, the path will be resolved with the given string value.
Otherwise, the path will be resolved with the latest available timestamp directory
that matches the given pattern.
"""
if config.storage.type == StorageType.file:
config.storage.base_dir = str(
resolve_path(
config.storage.base_dir,
config.root_dir,
pattern_or_timestamp_value,
)
)
if (
config.update_index_storage
and config.update_index_storage.type == StorageType.file
):
config.update_index_storage.base_dir = str(
resolve_path(
config.update_index_storage.base_dir,
config.root_dir,
pattern_or_timestamp_value,
)
)
if config.reporting.type == ReportingType.file:
config.reporting.base_dir = str(
resolve_path(
config.reporting.base_dir,
config.root_dir,
pattern_or_timestamp_value,
)
)
# TODO: must update filepath of lancedb (if used) until the new config engine has been implemented
# TODO: remove the type ignore annotations below once the new config engine has been refactored
vector_store_type = config.embeddings.vector_store["type"] # type: ignore
if vector_store_type == VectorStoreType.LanceDB:
db_uri = config.embeddings.vector_store["db_uri"] # type: ignore
lancedb_dir = Path(config.root_dir).resolve() / db_uri
config.embeddings.vector_store["db_uri"] = str(lancedb_dir) # type: ignore

View File

@ -1,4 +0,0 @@
# Copyright (c) 2024 Microsoft Corporation.
# Licensed under the MIT License
"""The Indexing Engine config typing package root."""

View File

@ -1,105 +0,0 @@
# Copyright (c) 2024 Microsoft Corporation.
# Licensed under the MIT License
"""A module containing 'PipelineCacheConfig', 'PipelineFileCacheConfig', 'PipelineMemoryCacheConfig', 'PipelineBlobCacheConfig', 'PipelineCosmosDBCacheConfig' models."""
from __future__ import annotations
from typing import Generic, Literal, TypeVar
from pydantic import BaseModel, Field
from graphrag.config.enums import CacheType
T = TypeVar("T")
class PipelineCacheConfig(BaseModel, Generic[T]):
"""Represent the cache configuration for the pipeline."""
type: T
class PipelineFileCacheConfig(PipelineCacheConfig[Literal[CacheType.file]]):
"""Represent the file cache configuration for the pipeline."""
type: Literal[CacheType.file] = CacheType.file
"""The type of cache."""
base_dir: str | None = Field(
description="The base directory for the cache.", default=None
)
"""The base directory for the cache."""
class PipelineMemoryCacheConfig(PipelineCacheConfig[Literal[CacheType.memory]]):
"""Represent the memory cache configuration for the pipeline."""
type: Literal[CacheType.memory] = CacheType.memory
"""The type of cache."""
class PipelineNoneCacheConfig(PipelineCacheConfig[Literal[CacheType.none]]):
"""Represent the none cache configuration for the pipeline."""
type: Literal[CacheType.none] = CacheType.none
"""The type of cache."""
class PipelineBlobCacheConfig(PipelineCacheConfig[Literal[CacheType.blob]]):
"""Represents the blob cache configuration for the pipeline."""
type: Literal[CacheType.blob] = CacheType.blob
"""The type of cache."""
base_dir: str | None = Field(
description="The base directory for the cache.", default=None
)
"""The base directory for the cache."""
connection_string: str | None = Field(
description="The blob cache connection string for the cache.", default=None
)
"""The blob cache connection string for the cache."""
container_name: str = Field(description="The container name for cache", default="")
"""The container name for cache"""
storage_account_blob_url: str | None = Field(
description="The storage account blob url for cache", default=None
)
"""The storage account blob url for cache"""
class PipelineCosmosDBCacheConfig(PipelineCacheConfig[Literal[CacheType.cosmosdb]]):
"""Represents the cosmosdb cache configuration for the pipeline."""
type: Literal[CacheType.cosmosdb] = CacheType.cosmosdb
"""The type of cache."""
base_dir: str | None = Field(
description="The cosmosdb database name for the cache.", default=None
)
"""The cosmosdb database name for the cache."""
container_name: str = Field(description="The container name for cache.", default="")
"""The container name for cache."""
connection_string: str | None = Field(
description="The cosmosdb primary key for the cache.", default=None
)
"""The cosmosdb primary key for the cache."""
cosmosdb_account_url: str | None = Field(
description="The cosmosdb account url for cache", default=None
)
"""The cosmosdb account url for cache"""
PipelineCacheConfigTypes = (
PipelineFileCacheConfig
| PipelineMemoryCacheConfig
| PipelineBlobCacheConfig
| PipelineNoneCacheConfig
| PipelineCosmosDBCacheConfig
)

View File

@ -1,110 +0,0 @@
# Copyright (c) 2024 Microsoft Corporation.
# Licensed under the MIT License
"""A module containing 'PipelineInputConfig', 'PipelineCSVInputConfig' and 'PipelineTextInputConfig' models."""
from __future__ import annotations
from typing import Generic, Literal, TypeVar
from pydantic import BaseModel, Field
from graphrag.config.enums import InputFileType, InputType
T = TypeVar("T")
class PipelineInputConfig(BaseModel, Generic[T]):
"""Represent the configuration for an input."""
file_type: T
"""The file type of input."""
type: InputType | None = Field(
description="The input type to use.",
default=None,
)
"""The input type to use."""
connection_string: str | None = Field(
description="The blob cache connection string for the input files.",
default=None,
)
"""The blob cache connection string for the input files."""
storage_account_blob_url: str | None = Field(
description="The storage account blob url for the input files.", default=None
)
"""The storage account blob url for the input files."""
container_name: str | None = Field(
description="The container name for input files.", default=None
)
"""The container name for the input files."""
base_dir: str | None = Field(
description="The base directory for the input files.", default=None
)
"""The base directory for the input files."""
file_pattern: str = Field(description="The regex file pattern for the input files.")
"""The regex file pattern for the input files."""
file_filter: dict[str, str] | None = Field(
description="The optional file filter for the input files.", default=None
)
"""The optional file filter for the input files."""
encoding: str | None = Field(
description="The encoding for the input files.", default=None
)
"""The encoding for the input files."""
class PipelineCSVInputConfig(PipelineInputConfig[Literal[InputFileType.csv]]):
"""Represent the configuration for a CSV input."""
file_type: Literal[InputFileType.csv] = InputFileType.csv
source_column: str | None = Field(
description="The column to use as the source of the document.", default=None
)
"""The column to use as the source of the document."""
timestamp_column: str | None = Field(
description="The column to use as the timestamp of the document.", default=None
)
"""The column to use as the timestamp of the document."""
timestamp_format: str | None = Field(
description="The format of the timestamp column, so it can be parsed correctly.",
default=None,
)
"""The format of the timestamp column, so it can be parsed correctly."""
text_column: str | None = Field(
description="The column to use as the text of the document.", default=None
)
"""The column to use as the text of the document."""
title_column: str | None = Field(
description="The column to use as the title of the document.", default=None
)
"""The column to use as the title of the document."""
class PipelineTextInputConfig(PipelineInputConfig[Literal[InputFileType.text]]):
"""Represent the configuration for a text input."""
file_type: Literal[InputFileType.text] = InputFileType.text
# Text Specific
title_text_length: int | None = Field(
description="Number of characters to use from the text as the title.",
default=None,
)
"""Number of characters to use from the text as the title."""
PipelineInputConfigTypes = PipelineCSVInputConfig | PipelineTextInputConfig
"""Represent the types of inputs that can be used in a pipeline."""

View File

@ -1,66 +0,0 @@
# Copyright (c) 2024 Microsoft Corporation.
# Licensed under the MIT License
"""A module containing 'PipelineConfig' model."""
from __future__ import annotations
from devtools import pformat
from pydantic import BaseModel, Field
from graphrag.index.config.cache import PipelineCacheConfigTypes
from graphrag.index.config.input import PipelineInputConfigTypes
from graphrag.index.config.reporting import PipelineReportingConfigTypes
from graphrag.index.config.storage import PipelineStorageConfigTypes
from graphrag.index.config.workflow import PipelineWorkflowReference
class PipelineConfig(BaseModel):
"""Represent the configuration for a pipeline."""
def __repr__(self) -> str:
"""Get a string representation."""
return pformat(self, highlight=False)
def __str__(self):
"""Get a string representation."""
return str(self.model_dump_json(indent=4))
extends: list[str] | str | None = Field(
description="Extends another pipeline configuration", default=None
)
"""Extends another pipeline configuration"""
input: PipelineInputConfigTypes | None = Field(
default=None, discriminator="file_type"
)
"""The input configuration for the pipeline."""
reporting: PipelineReportingConfigTypes | None = Field(
default=None, discriminator="type"
)
"""The reporting configuration for the pipeline."""
storage: PipelineStorageConfigTypes | None = Field(
default=None, discriminator="type"
)
"""The storage configuration for the pipeline."""
update_index_storage: PipelineStorageConfigTypes | None = Field(
default=None, discriminator="type"
)
"""The storage configuration for the updated index."""
cache: PipelineCacheConfigTypes | None = Field(default=None, discriminator="type")
"""The cache configuration for the pipeline."""
root_dir: str | None = Field(
description="The root directory for the pipeline. All other paths will be based on this root_dir.",
default=None,
)
"""The root directory for the pipeline."""
workflows: list[PipelineWorkflowReference] = Field(
description="The workflows for the pipeline.", default_factory=list
)
"""The workflows for the pipeline."""

View File

@ -1,103 +0,0 @@
# Copyright (c) 2024 Microsoft Corporation.
# Licensed under the MIT License
"""A module containing 'PipelineStorageConfig', 'PipelineFileStorageConfig','PipelineMemoryStorageConfig', 'PipelineBlobStorageConfig', and 'PipelineCosmosDBStorageConfig' models."""
from __future__ import annotations
from typing import Generic, Literal, TypeVar
from pydantic import BaseModel, Field
from graphrag.config.enums import StorageType
T = TypeVar("T")
class PipelineStorageConfig(BaseModel, Generic[T]):
"""Represent the storage configuration for the pipeline."""
type: T
class PipelineFileStorageConfig(PipelineStorageConfig[Literal[StorageType.file]]):
"""Represent the file storage configuration for the pipeline."""
type: Literal[StorageType.file] = StorageType.file
"""The type of storage."""
base_dir: str | None = Field(
description="The base directory for the storage.", default=None
)
"""The base directory for the storage."""
class PipelineMemoryStorageConfig(PipelineStorageConfig[Literal[StorageType.memory]]):
"""Represent the memory storage configuration for the pipeline."""
type: Literal[StorageType.memory] = StorageType.memory
"""The type of storage."""
class PipelineBlobStorageConfig(PipelineStorageConfig[Literal[StorageType.blob]]):
"""Represents the blob storage configuration for the pipeline."""
type: Literal[StorageType.blob] = StorageType.blob
"""The type of storage."""
connection_string: str | None = Field(
description="The blob storage connection string for the storage.", default=None
)
"""The blob storage connection string for the storage."""
container_name: str = Field(
description="The container name for storage", default=""
)
"""The container name for storage."""
base_dir: str | None = Field(
description="The base directory for the storage.", default=None
)
"""The base directory for the storage."""
storage_account_blob_url: str | None = Field(
description="The storage account blob url.", default=None
)
"""The storage account blob url."""
class PipelineCosmosDBStorageConfig(
PipelineStorageConfig[Literal[StorageType.cosmosdb]]
):
"""Represents the cosmosdb storage configuration for the pipeline."""
type: Literal[StorageType.cosmosdb] = StorageType.cosmosdb
"""The type of storage."""
connection_string: str | None = Field(
description="The cosmosdb storage primary key for the storage.", default=None
)
"""The cosmosdb storage primary key for the storage."""
container_name: str = Field(
description="The container name for storage", default=""
)
"""The container name for storage."""
base_dir: str | None = Field(
description="The base directory for the storage.", default=None
)
"""The base directory for the storage."""
cosmosdb_account_url: str | None = Field(
description="The cosmosdb account url.", default=None
)
"""The cosmosdb account url."""
PipelineStorageConfigTypes = (
PipelineFileStorageConfig
| PipelineMemoryStorageConfig
| PipelineBlobStorageConfig
| PipelineCosmosDBStorageConfig
)
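These pipeline-scoped storage models are deleted; the runner now builds storage directly from the top-level output settings (see the run_workflows hunk further down). A minimal sketch of that pattern, assuming a loaded GraphRagConfig and that StorageFactory sits at its usual import path:

# Sketch only, not part of this diff; mirrors the run_workflows hunk below.
from graphrag.storage.factory import StorageFactory  # import path assumed

def create_output_storage(config):
    storage_config = config.output.model_dump()
    return StorageFactory().create_storage(
        storage_type=storage_config["type"],
        kwargs=storage_config,
    )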

View File

@ -1,25 +0,0 @@
# Copyright (c) 2024 Microsoft Corporation.
# Licensed under the MIT License
"""A module containing 'PipelineWorkflowReference' model."""
from __future__ import annotations
from typing import Any
from pydantic import BaseModel, Field
PipelineWorkflowConfig = dict[str, Any]
"""Represent a configuration for a workflow."""
class PipelineWorkflowReference(BaseModel):
"""Represent a reference to a workflow, and can optionally be the workflow itself."""
name: str | None = Field(description="Name of the workflow.", default=None)
"""Name of the workflow."""
config: PipelineWorkflowConfig | None = Field(
description="The optional configuration for the workflow.", default=None
)
"""The optional configuration for the workflow."""

View File

@ -1,456 +0,0 @@
# Copyright (c) 2024 Microsoft Corporation.
# Licensed under the MIT License
"""Default configuration methods definition."""
import json
import logging
from pathlib import Path
from graphrag.config.enums import (
CacheType,
InputFileType,
ReportingType,
StorageType,
)
from graphrag.config.models.graph_rag_config import GraphRagConfig
from graphrag.config.models.storage_config import StorageConfig
from graphrag.index.config.cache import (
PipelineBlobCacheConfig,
PipelineCacheConfigTypes,
PipelineCosmosDBCacheConfig,
PipelineFileCacheConfig,
PipelineMemoryCacheConfig,
PipelineNoneCacheConfig,
)
from graphrag.index.config.embeddings import get_embedded_fields, get_embedding_settings
from graphrag.index.config.input import (
PipelineCSVInputConfig,
PipelineInputConfigTypes,
PipelineTextInputConfig,
)
from graphrag.index.config.pipeline import (
PipelineConfig,
)
from graphrag.index.config.reporting import (
PipelineBlobReportingConfig,
PipelineConsoleReportingConfig,
PipelineFileReportingConfig,
PipelineReportingConfigTypes,
)
from graphrag.index.config.storage import (
PipelineBlobStorageConfig,
PipelineCosmosDBStorageConfig,
PipelineFileStorageConfig,
PipelineMemoryStorageConfig,
PipelineStorageConfigTypes,
)
from graphrag.index.config.workflow import (
PipelineWorkflowReference,
)
from graphrag.index.workflows import (
compute_communities,
create_base_text_units,
create_final_communities,
create_final_community_reports,
create_final_covariates,
create_final_documents,
create_final_entities,
create_final_nodes,
create_final_relationships,
create_final_text_units,
extract_graph,
generate_text_embeddings,
)
log = logging.getLogger(__name__)
builtin_document_attributes: set[str] = {
"id",
"source",
"text",
"title",
"timestamp",
"year",
"month",
"day",
"hour",
"minute",
"second",
}
def create_pipeline_config(settings: GraphRagConfig, verbose=False) -> PipelineConfig:
"""Get the default config for the pipeline."""
# relative to the root_dir
if verbose:
_log_llm_settings(settings)
skip_workflows = settings.skip_workflows
embedded_fields = get_embedded_fields(settings)
covariates_enabled = (
settings.claim_extraction.enabled
and create_final_covariates not in skip_workflows
)
result = PipelineConfig(
root_dir=settings.root_dir,
input=_get_pipeline_input_config(settings),
reporting=_get_reporting_config(settings),
storage=_get_storage_config(settings, settings.storage),
update_index_storage=_get_storage_config(
settings, settings.update_index_storage
),
cache=_get_cache_config(settings),
workflows=[
*_document_workflows(settings),
*_text_unit_workflows(settings, covariates_enabled),
*_graph_workflows(settings),
*_community_workflows(settings, covariates_enabled),
*(_covariate_workflows(settings) if covariates_enabled else []),
*(_embeddings_workflows(settings, embedded_fields)),
],
)
# Remove any workflows that were specified to be skipped
log.info("skipping workflows %s", ",".join(skip_workflows))
result.workflows = [w for w in result.workflows if w.name not in skip_workflows]
return result
def _log_llm_settings(settings: GraphRagConfig) -> None:
log.info(
"Using LLM Config %s",
json.dumps(
{**settings.entity_extraction.llm.model_dump(), "api_key": "*****"},
indent=4,
),
)
log.info(
"Using Embeddings Config %s",
json.dumps(
{**settings.embeddings.llm.model_dump(), "api_key": "*****"}, indent=4
),
)
def _document_workflows(
settings: GraphRagConfig,
) -> list[PipelineWorkflowReference]:
return [
PipelineWorkflowReference(
name=create_final_documents,
config={
"document_attribute_columns": list(
{*(settings.input.document_attribute_columns)}
- builtin_document_attributes
),
},
),
]
def _text_unit_workflows(
settings: GraphRagConfig,
covariates_enabled: bool,
) -> list[PipelineWorkflowReference]:
return [
PipelineWorkflowReference(
name=create_base_text_units,
config={
"chunks": settings.chunks,
"snapshot_transient": settings.snapshots.transient,
},
),
PipelineWorkflowReference(
name=create_final_text_units,
config={
"covariates_enabled": covariates_enabled,
},
),
]
def _graph_workflows(settings: GraphRagConfig) -> list[PipelineWorkflowReference]:
return [
PipelineWorkflowReference(
name=extract_graph,
config={
"snapshot_graphml": settings.snapshots.graphml,
"snapshot_transient": settings.snapshots.transient,
"entity_extract": {
**settings.entity_extraction.parallelization.model_dump(),
"async_mode": settings.entity_extraction.async_mode,
"strategy": settings.entity_extraction.resolved_strategy(
settings.root_dir, settings.encoding_model
),
"entity_types": settings.entity_extraction.entity_types,
},
"summarize_descriptions": {
**settings.summarize_descriptions.parallelization.model_dump(),
"async_mode": settings.summarize_descriptions.async_mode,
"strategy": settings.summarize_descriptions.resolved_strategy(
settings.root_dir,
),
},
},
),
PipelineWorkflowReference(
name=compute_communities,
config={
"cluster_graph": settings.cluster_graph,
"snapshot_transient": settings.snapshots.transient,
},
),
PipelineWorkflowReference(
name=create_final_entities,
config={},
),
PipelineWorkflowReference(
name=create_final_relationships,
config={},
),
PipelineWorkflowReference(
name=create_final_nodes,
config={
"layout_enabled": settings.umap.enabled,
"embed_graph": settings.embed_graph,
},
),
]
def _community_workflows(
settings: GraphRagConfig, covariates_enabled: bool
) -> list[PipelineWorkflowReference]:
return [
PipelineWorkflowReference(name=create_final_communities),
PipelineWorkflowReference(
name=create_final_community_reports,
config={
"covariates_enabled": covariates_enabled,
"create_community_reports": {
**settings.community_reports.parallelization.model_dump(),
"async_mode": settings.community_reports.async_mode,
"strategy": settings.community_reports.resolved_strategy(
settings.root_dir
),
},
},
),
]
def _covariate_workflows(
settings: GraphRagConfig,
) -> list[PipelineWorkflowReference]:
return [
PipelineWorkflowReference(
name=create_final_covariates,
config={
"claim_extract": {
**settings.claim_extraction.parallelization.model_dump(),
"strategy": settings.claim_extraction.resolved_strategy(
settings.root_dir, settings.encoding_model
),
},
},
)
]
def _embeddings_workflows(
settings: GraphRagConfig, embedded_fields: set[str]
) -> list[PipelineWorkflowReference]:
return [
PipelineWorkflowReference(
name=generate_text_embeddings,
config={
"snapshot_embeddings": settings.snapshots.embeddings,
"text_embed": get_embedding_settings(settings.embeddings),
"embedded_fields": embedded_fields,
},
),
]
def _get_pipeline_input_config(
settings: GraphRagConfig,
) -> PipelineInputConfigTypes:
file_type = settings.input.file_type
match file_type:
case InputFileType.csv:
return PipelineCSVInputConfig(
base_dir=settings.input.base_dir,
file_pattern=settings.input.file_pattern,
encoding=settings.input.encoding,
source_column=settings.input.source_column,
timestamp_column=settings.input.timestamp_column,
timestamp_format=settings.input.timestamp_format,
text_column=settings.input.text_column,
title_column=settings.input.title_column,
type=settings.input.type,
connection_string=settings.input.connection_string,
storage_account_blob_url=settings.input.storage_account_blob_url,
container_name=settings.input.container_name,
)
case InputFileType.text:
return PipelineTextInputConfig(
base_dir=settings.input.base_dir,
file_pattern=settings.input.file_pattern,
encoding=settings.input.encoding,
type=settings.input.type,
connection_string=settings.input.connection_string,
storage_account_blob_url=settings.input.storage_account_blob_url,
container_name=settings.input.container_name,
)
case _:
msg = f"Unknown input type: {file_type}"
raise ValueError(msg)
def _get_reporting_config(
settings: GraphRagConfig,
) -> PipelineReportingConfigTypes:
"""Get the reporting config from the settings."""
match settings.reporting.type:
case ReportingType.file:
# relative to the root_dir
return PipelineFileReportingConfig(base_dir=settings.reporting.base_dir)
case ReportingType.blob:
connection_string = settings.reporting.connection_string
storage_account_blob_url = settings.reporting.storage_account_blob_url
container_name = settings.reporting.container_name
if container_name is None:
msg = "Container name must be provided for blob reporting."
raise ValueError(msg)
if connection_string is None and storage_account_blob_url is None:
msg = "Connection string or storage account blob url must be provided for blob reporting."
raise ValueError(msg)
return PipelineBlobReportingConfig(
connection_string=connection_string,
container_name=container_name,
base_dir=settings.reporting.base_dir,
storage_account_blob_url=storage_account_blob_url,
)
case ReportingType.console:
return PipelineConsoleReportingConfig()
case _:
# relative to the root_dir
return PipelineFileReportingConfig(base_dir=settings.reporting.base_dir)
def _get_storage_config(
settings: GraphRagConfig,
storage_settings: StorageConfig | None,
) -> PipelineStorageConfigTypes | None:
"""Get the storage type from the settings."""
if not storage_settings:
return None
root_dir = settings.root_dir
match storage_settings.type:
case StorageType.memory:
return PipelineMemoryStorageConfig()
case StorageType.file:
# relative to the root_dir
base_dir = storage_settings.base_dir
if base_dir is None:
msg = "Base directory must be provided for file storage."
raise ValueError(msg)
return PipelineFileStorageConfig(base_dir=str(Path(root_dir) / base_dir))
case StorageType.blob:
connection_string = storage_settings.connection_string
storage_account_blob_url = storage_settings.storage_account_blob_url
container_name = storage_settings.container_name
if container_name is None:
msg = "Container name must be provided for blob storage."
raise ValueError(msg)
if connection_string is None and storage_account_blob_url is None:
msg = "Connection string or storage account blob url must be provided for blob storage."
raise ValueError(msg)
return PipelineBlobStorageConfig(
connection_string=connection_string,
container_name=container_name,
base_dir=storage_settings.base_dir,
storage_account_blob_url=storage_account_blob_url,
)
case StorageType.cosmosdb:
cosmosdb_account_url = storage_settings.cosmosdb_account_url
connection_string = storage_settings.connection_string
base_dir = storage_settings.base_dir
container_name = storage_settings.container_name
if connection_string is None and cosmosdb_account_url is None:
msg = "Connection string or cosmosDB account url must be provided for cosmosdb storage."
raise ValueError(msg)
if base_dir is None:
msg = "Base directory must be provided for cosmosdb storage."
raise ValueError(msg)
if container_name is None:
msg = "Container name must be provided for cosmosdb storage."
raise ValueError(msg)
return PipelineCosmosDBStorageConfig(
cosmosdb_account_url=cosmosdb_account_url,
connection_string=connection_string,
base_dir=storage_settings.base_dir,
container_name=container_name,
)
case _:
# relative to the root_dir
base_dir = storage_settings.base_dir
if base_dir is None:
msg = "Base directory must be provided for file storage."
raise ValueError(msg)
return PipelineFileStorageConfig(base_dir=str(Path(root_dir) / base_dir))
def _get_cache_config(
settings: GraphRagConfig,
) -> PipelineCacheConfigTypes:
"""Get the cache type from the settings."""
match settings.cache.type:
case CacheType.memory:
return PipelineMemoryCacheConfig()
case CacheType.file:
# relative to root dir
return PipelineFileCacheConfig(base_dir=settings.cache.base_dir)
case CacheType.none:
return PipelineNoneCacheConfig()
case CacheType.blob:
connection_string = settings.cache.connection_string
storage_account_blob_url = settings.cache.storage_account_blob_url
container_name = settings.cache.container_name
if container_name is None:
msg = "Container name must be provided for blob cache."
raise ValueError(msg)
if connection_string is None and storage_account_blob_url is None:
msg = "Connection string or storage account blob url must be provided for blob cache."
raise ValueError(msg)
return PipelineBlobCacheConfig(
connection_string=connection_string,
container_name=container_name,
base_dir=settings.cache.base_dir,
storage_account_blob_url=storage_account_blob_url,
)
case CacheType.cosmosdb:
cosmosdb_account_url = settings.cache.cosmosdb_account_url
connection_string = settings.cache.connection_string
base_dir = settings.cache.base_dir
container_name = settings.cache.container_name
if base_dir is None:
msg = "Base directory must be provided for cosmosdb cache."
raise ValueError(msg)
if container_name is None:
msg = "Container name must be provided for cosmosdb cache."
raise ValueError(msg)
if connection_string is None and cosmosdb_account_url is None:
msg = "Connection string or cosmosDB account url must be provided for cosmosdb cache."
raise ValueError(msg)
return PipelineCosmosDBCacheConfig(
cosmosdb_account_url=cosmosdb_account_url,
connection_string=connection_string,
base_dir=base_dir,
container_name=container_name,
)
case _:
# relative to root dir
return PipelineFileCacheConfig(base_dir="./cache")

View File

@ -1,25 +0,0 @@
# Copyright (c) 2024 Microsoft Corporation.
# Licensed under the MIT License
"""GraphRAG indexing error types."""
class NoWorkflowsDefinedError(ValueError):
"""Exception for no workflows defined."""
def __init__(self):
super().__init__("No workflows defined.")
class UndefinedWorkflowError(ValueError):
"""Exception for invalid verb input."""
def __init__(self):
super().__init__("Workflow name is undefined.")
class UnknownWorkflowError(ValueError):
"""Exception for invalid verb input."""
def __init__(self, name: str):
super().__init__(f"Unknown workflow: {name}")

View File

@ -70,12 +70,12 @@ async def create_final_community_reports(
)
community_reports = await summarize_communities(
local_contexts,
nodes,
community_hierarchy,
callbacks,
cache,
summarization_strategy,
local_contexts=local_contexts,
nodes=nodes,
community_hierarchy=community_hierarchy,
callbacks=callbacks,
cache=cache,
strategy=summarization_strategy,
async_mode=async_mode,
num_threads=num_threads,
)

View File

@ -31,15 +31,15 @@ async def create_final_covariates(
# this also results in text_unit_id being copied to the output covariate table
text_units["text_unit_id"] = text_units["id"]
covariates = await extract_covariates(
text_units,
callbacks,
cache,
"text",
covariate_type,
extraction_strategy,
async_mode,
entity_types,
num_threads,
input=text_units,
callbacks=callbacks,
cache=cache,
column="text",
covariate_type=covariate_type,
strategy=extraction_strategy,
async_mode=async_mode,
entity_types=entity_types,
num_threads=num_threads,
)
text_units.drop(columns=["text_unit_id"], inplace=True) # don't pollute the global
covariates["id"] = covariates["covariate_type"].apply(lambda _x: str(uuid4()))

View File

@ -31,9 +31,9 @@ async def extract_graph(
"""All the steps to create the base entity graph."""
# this returns a graph for each text unit, to be merged later
entities, relationships = await extract_entities(
text_units,
callbacks,
cache,
text_units=text_units,
callbacks=callbacks,
cache=cache,
text_column="text",
id_column="id",
strategy=extraction_strategy,
@ -55,10 +55,10 @@ async def extract_graph(
raise ValueError(error_msg)
entity_summaries, relationship_summaries = await summarize_descriptions(
entities,
relationships,
callbacks,
cache,
entities_df=entities,
relationships_df=relationships,
callbacks=callbacks,
cache=cache,
strategy=summarization_strategy,
num_threads=summarization_num_threads,
)

View File

@ -9,7 +9,7 @@ import pandas as pd
from graphrag.cache.pipeline_cache import PipelineCache
from graphrag.callbacks.workflow_callbacks import WorkflowCallbacks
from graphrag.index.config.embeddings import (
from graphrag.config.embeddings import (
community_full_content_embedding,
community_summary_embedding,
community_title_embedding,
@ -119,9 +119,9 @@ async def _run_and_snapshot_embeddings(
"""All the steps to generate single embedding."""
if text_embed_config:
data["embedding"] = await embed_text(
data,
callbacks,
cache,
input=data,
callbacks=callbacks,
cache=cache,
embed_column=embed_column,
embedding_name=name,
strategy=text_embed_config["strategy"],

View File

@ -6,11 +6,10 @@
import logging
import re
from io import BytesIO
from typing import cast
import pandas as pd
from graphrag.index.config.input import PipelineCSVInputConfig, PipelineInputConfig
from graphrag.config.models.input_config import InputConfig
from graphrag.index.utils.hashing import gen_sha512_hash
from graphrag.logger.base import ProgressLogger
from graphrag.storage.pipeline_storage import PipelineStorage
@ -23,19 +22,18 @@ input_type = "csv"
async def load(
config: PipelineInputConfig,
config: InputConfig,
progress: ProgressLogger | None,
storage: PipelineStorage,
) -> pd.DataFrame:
"""Load csv inputs from a directory."""
csv_config = cast("PipelineCSVInputConfig", config)
log.info("Loading csv files from %s", csv_config.base_dir)
log.info("Loading csv files from %s", config.base_dir)
async def load_file(path: str, group: dict | None) -> pd.DataFrame:
if group is None:
group = {}
buffer = BytesIO(await storage.get(path, as_bytes=True))
data = pd.read_csv(buffer, encoding=config.encoding or "latin-1")
data = pd.read_csv(buffer, encoding=config.encoding)
additional_keys = group.keys()
if len(additional_keys) > 0:
data[[*additional_keys]] = data.apply(
@ -43,51 +41,49 @@ async def load(
)
if "id" not in data.columns:
data["id"] = data.apply(lambda x: gen_sha512_hash(x, x.keys()), axis=1)
if csv_config.source_column is not None and "source" not in data.columns:
if csv_config.source_column not in data.columns:
if config.source_column is not None and "source" not in data.columns:
if config.source_column not in data.columns:
log.warning(
"source_column %s not found in csv file %s",
csv_config.source_column,
config.source_column,
path,
)
else:
data["source"] = data.apply(
lambda x: x[csv_config.source_column], axis=1
)
if csv_config.text_column is not None and "text" not in data.columns:
if csv_config.text_column not in data.columns:
data["source"] = data.apply(lambda x: x[config.source_column], axis=1)
if config.text_column is not None and "text" not in data.columns:
if config.text_column not in data.columns:
log.warning(
"text_column %s not found in csv file %s",
csv_config.text_column,
config.text_column,
path,
)
else:
data["text"] = data.apply(lambda x: x[csv_config.text_column], axis=1)
if csv_config.title_column is not None and "title" not in data.columns:
if csv_config.title_column not in data.columns:
data["text"] = data.apply(lambda x: x[config.text_column], axis=1)
if config.title_column is not None and "title" not in data.columns:
if config.title_column not in data.columns:
log.warning(
"title_column %s not found in csv file %s",
csv_config.title_column,
config.title_column,
path,
)
else:
data["title"] = data.apply(lambda x: x[csv_config.title_column], axis=1)
data["title"] = data.apply(lambda x: x[config.title_column], axis=1)
if csv_config.timestamp_column is not None:
fmt = csv_config.timestamp_format
if config.timestamp_column is not None:
fmt = config.timestamp_format
if fmt is None:
msg = "Must specify timestamp_format if timestamp_column is specified"
raise ValueError(msg)
if csv_config.timestamp_column not in data.columns:
if config.timestamp_column not in data.columns:
log.warning(
"timestamp_column %s not found in csv file %s",
csv_config.timestamp_column,
config.timestamp_column,
path,
)
else:
data["timestamp"] = pd.to_datetime(
data[csv_config.timestamp_column], format=fmt
data[config.timestamp_column], format=fmt
)
# TODO: There's probably a less gross way to do this

View File

@ -12,7 +12,6 @@ import pandas as pd
from graphrag.config.enums import InputType
from graphrag.config.models.input_config import InputConfig
from graphrag.index.config.input import PipelineInputConfig
from graphrag.index.input.csv import input_type as csv
from graphrag.index.input.csv import load as load_csv
from graphrag.index.input.text import input_type as text
@ -30,7 +29,7 @@ loaders: dict[str, Callable[..., Awaitable[pd.DataFrame]]] = {
async def create_input(
config: PipelineInputConfig | InputConfig,
config: InputConfig,
progress_reporter: ProgressLogger | None = None,
root_dir: str | None = None,
) -> pd.DataFrame:

View File

@ -10,7 +10,7 @@ from typing import Any
import pandas as pd
from graphrag.index.config.input import PipelineInputConfig
from graphrag.config.models.input_config import InputConfig
from graphrag.index.utils.hashing import gen_sha512_hash
from graphrag.logger.base import ProgressLogger
from graphrag.storage.pipeline_storage import PipelineStorage
@ -23,7 +23,7 @@ log = logging.getLogger(__name__)
async def load(
config: PipelineInputConfig,
config: InputConfig,
progress: ProgressLogger | None,
storage: PipelineStorage,
) -> pd.DataFrame:

View File

@ -19,11 +19,12 @@ from fnllm.openai import (
create_openai_embeddings_llm,
)
from fnllm.openai.types.chat.parameters import OpenAIChatParameters
from pydantic import TypeAdapter
import graphrag.config.defaults as defs
from graphrag.config.enums import LLMType
from graphrag.config.models.llm_parameters import LLMParameters
from graphrag.config.models.language_model_config import (
LanguageModelConfig, # noqa: TC001
)
from graphrag.index.llm.manager import ChatLLMSingleton, EmbeddingsLLMSingleton
from .mock_llm import MockChatLLM
@ -93,17 +94,9 @@ def create_cache(cache: PipelineCache | None, name: str) -> LLMCache | None:
return GraphRagLLMCache(cache).child(name)
def read_llm_params(llm_args: dict[str, Any]) -> LLMParameters:
"""Read the LLM parameters from the arguments."""
if llm_args == {}:
msg = "LLM arguments are required"
raise ValueError(msg)
return TypeAdapter(LLMParameters).validate_python(llm_args)
def load_llm(
name: str,
config: LLMParameters,
config: LanguageModelConfig,
*,
callbacks: WorkflowCallbacks,
cache: PipelineCache | None,
@ -133,7 +126,7 @@ def load_llm(
def load_llm_embeddings(
name: str,
llm_config: LLMParameters,
llm_config: LanguageModelConfig,
*,
callbacks: WorkflowCallbacks,
cache: PipelineCache | None,
@ -174,7 +167,7 @@ def _create_error_handler(callbacks: WorkflowCallbacks) -> ErrorHandlerFn:
def _load_openai_chat_llm(
on_error: ErrorHandlerFn,
cache: LLMCache,
config: LLMParameters,
config: LanguageModelConfig,
azure=False,
):
return _create_openai_chat_llm(
@ -187,7 +180,7 @@ def _load_openai_chat_llm(
def _load_openai_embeddings_llm(
on_error: ErrorHandlerFn,
cache: LLMCache,
config: LLMParameters,
config: LanguageModelConfig,
azure=False,
):
return _create_openai_embeddings_llm(
@ -197,8 +190,8 @@ def _load_openai_embeddings_llm(
)
def _create_openai_config(config: LLMParameters, azure: bool) -> OpenAIConfig:
encoding_model = config.encoding_model or defs.ENCODING_MODEL
def _create_openai_config(config: LanguageModelConfig, azure: bool) -> OpenAIConfig:
encoding_model = config.encoding_model
json_strategy = (
JsonStrategy.VALID if config.model_supports_json else JsonStrategy.LOOSE
)
@ -253,19 +246,19 @@ def _create_openai_config(config: LLMParameters, azure: bool) -> OpenAIConfig:
def _load_azure_openai_chat_llm(
on_error: ErrorHandlerFn, cache: LLMCache, config: LLMParameters
on_error: ErrorHandlerFn, cache: LLMCache, config: LanguageModelConfig
):
return _load_openai_chat_llm(on_error, cache, config, True)
def _load_azure_openai_embeddings_llm(
on_error: ErrorHandlerFn, cache: LLMCache, config: LLMParameters
on_error: ErrorHandlerFn, cache: LLMCache, config: LanguageModelConfig
):
return _load_openai_embeddings_llm(on_error, cache, config, True)
def _load_static_response(
_on_error: ErrorHandlerFn, _cache: PipelineCache, config: LLMParameters
_on_error: ErrorHandlerFn, _cache: PipelineCache, config: LanguageModelConfig
) -> ChatLLM:
if config.responses is None:
msg = "Static response LLM requires responses"

View File

@ -12,8 +12,8 @@ import pandas as pd
from graphrag.cache.pipeline_cache import PipelineCache
from graphrag.callbacks.workflow_callbacks import WorkflowCallbacks
from graphrag.config.embeddings import create_collection_name
from graphrag.index.operations.embed_text.strategies.typing import TextEmbeddingStrategy
from graphrag.utils.embeddings import create_collection_name
from graphrag.vector_stores.base import BaseVectorStore, VectorStoreDocument
from graphrag.vector_stores.factory import VectorStoreFactory
@ -87,23 +87,23 @@ async def embed_text(
embedding_name, vector_store_config
)
return await _text_embed_with_vector_store(
input,
callbacks,
cache,
embed_column,
strategy,
vector_store,
vector_store_workflow_config,
input=input,
callbacks=callbacks,
cache=cache,
embed_column=embed_column,
strategy=strategy,
vector_store=vector_store,
vector_store_config=vector_store_workflow_config,
id_column=id_column,
title_column=title_column,
)
return await _text_embed_in_memory(
input,
callbacks,
cache,
embed_column,
strategy,
input=input,
callbacks=callbacks,
cache=cache,
embed_column=embed_column,
strategy=strategy,
)
@ -176,12 +176,7 @@ async def _text_embed_with_vector_store(
texts: list[str] = batch[embed_column].to_numpy().tolist()
titles: list[str] = batch[title].to_numpy().tolist()
ids: list[str] = batch[id_column].to_numpy().tolist()
result = await strategy_exec(
texts,
callbacks,
cache,
strategy_args,
)
result = await strategy_exec(texts, callbacks, cache, strategy_args)
if result.embeddings:
embeddings = [
embedding for embedding in result.embeddings if embedding is not None

View File

@ -9,12 +9,10 @@ from typing import Any
import numpy as np
from fnllm import EmbeddingsLLM
from pydantic import TypeAdapter
import graphrag.config.defaults as defs
from graphrag.cache.pipeline_cache import PipelineCache
from graphrag.callbacks.workflow_callbacks import WorkflowCallbacks
from graphrag.config.models.llm_parameters import LLMParameters
from graphrag.config.models.language_model_config import LanguageModelConfig
from graphrag.index.llm.load_llm import load_llm_embeddings
from graphrag.index.operations.embed_text.strategies.typing import TextEmbeddingResult
from graphrag.index.text_splitting.text_splitting import TokenTextSplitter
@ -34,9 +32,10 @@ async def run(
if is_null(input):
return TextEmbeddingResult(embeddings=None)
llm_config = TypeAdapter(LLMParameters).validate_python(args.get("llm", {}))
batch_size = args.get("batch_size", 16)
batch_max_tokens = args.get("batch_max_tokens", 8191)
llm_config = args["llm"]
llm_config = LanguageModelConfig(**args["llm"])
splitter = _get_splitter(llm_config, batch_max_tokens)
llm = _get_llm(llm_config, callbacks, cache)
semaphore: asyncio.Semaphore = asyncio.Semaphore(args.get("num_threads", 4))
@ -66,15 +65,17 @@ async def run(
return TextEmbeddingResult(embeddings=embeddings)
def _get_splitter(config: LLMParameters, batch_max_tokens: int) -> TokenTextSplitter:
def _get_splitter(
config: LanguageModelConfig, batch_max_tokens: int
) -> TokenTextSplitter:
return TokenTextSplitter(
encoding_name=config.encoding_model or defs.ENCODING_MODEL,
encoding_name=config.encoding_model,
chunk_size=batch_max_tokens,
)
def _get_llm(
config: LLMParameters,
config: LanguageModelConfig,
callbacks: WorkflowCallbacks,
cache: PipelineCache,
) -> EmbeddingsLLM:

View File

@ -14,7 +14,8 @@ import graphrag.config.defaults as defs
from graphrag.cache.pipeline_cache import PipelineCache
from graphrag.callbacks.workflow_callbacks import WorkflowCallbacks
from graphrag.config.enums import AsyncType
from graphrag.index.llm.load_llm import load_llm, read_llm_params
from graphrag.config.models.language_model_config import LanguageModelConfig
from graphrag.index.llm.load_llm import load_llm
from graphrag.index.operations.extract_covariates.claim_extractor import ClaimExtractor
from graphrag.index.operations.extract_covariates.typing import (
Covariate,
@ -52,7 +53,12 @@ async def extract_covariates(
async def run_strategy(row):
text = row[column]
result = await run_claim_extraction(
text, entity_types, resolved_entities_map, callbacks, cache, strategy_config
input=text,
entity_types=entity_types,
resolved_entities_map=resolved_entities_map,
callbacks=callbacks,
cache=cache,
strategy_config=strategy_config,
)
return [
create_row_from_claim_data(row, item, covariate_type)
@ -83,8 +89,13 @@ async def run_claim_extraction(
strategy_config: dict[str, Any],
) -> CovariateExtractionResult:
"""Run the Claim extraction chain."""
llm_config = read_llm_params(strategy_config.get("llm", {}))
llm = load_llm("claim_extraction", llm_config, callbacks=callbacks, cache=cache)
llm_config = LanguageModelConfig(**strategy_config["llm"])
llm = load_llm(
"claim_extraction",
llm_config,
callbacks=callbacks,
cache=cache,
)
extraction_prompt = strategy_config.get("extraction_prompt")
max_gleanings = strategy_config.get("max_gleanings", defs.CLAIM_MAX_GLEANINGS)
tuple_delimiter = strategy_config.get("tuple_delimiter")

View File

@ -9,7 +9,8 @@ from fnllm import ChatLLM
import graphrag.config.defaults as defs
from graphrag.cache.pipeline_cache import PipelineCache
from graphrag.callbacks.workflow_callbacks import WorkflowCallbacks
from graphrag.index.llm.load_llm import load_llm, read_llm_params
from graphrag.config.models.language_model_config import LanguageModelConfig
from graphrag.index.llm.load_llm import load_llm
from graphrag.index.operations.extract_entities.graph_extractor import GraphExtractor
from graphrag.index.operations.extract_entities.typing import (
Document,
@ -27,8 +28,13 @@ async def run_graph_intelligence(
args: StrategyConfig,
) -> EntityExtractionResult:
"""Run the graph intelligence entity extraction strategy."""
llm_config = read_llm_params(args.get("llm", {}))
llm = load_llm("entity_extraction", llm_config, callbacks=callbacks, cache=cache)
llm_config = LanguageModelConfig(**args["llm"])
llm = load_llm(
"entity_extraction",
llm_config,
callbacks=callbacks,
cache=cache,
)
return await run_extract_entities(llm, docs, entity_types, callbacks, args)

View File

@ -10,7 +10,8 @@ from fnllm import ChatLLM
from graphrag.cache.pipeline_cache import PipelineCache
from graphrag.callbacks.workflow_callbacks import WorkflowCallbacks
from graphrag.index.llm.load_llm import load_llm, read_llm_params
from graphrag.config.models.language_model_config import LanguageModelConfig
from graphrag.index.llm.load_llm import load_llm
from graphrag.index.operations.summarize_communities.community_reports_extractor.community_reports_extractor import (
CommunityReportsExtractor,
)
@ -33,8 +34,13 @@ async def run_graph_intelligence(
args: StrategyConfig,
) -> CommunityReport | None:
"""Run the graph intelligence entity extraction strategy."""
llm_config = read_llm_params(args.get("llm", {}))
llm = load_llm("community_reporting", llm_config, callbacks=callbacks, cache=cache)
llm_config = LanguageModelConfig(**args["llm"])
llm = load_llm(
"community_reporting",
llm_config,
callbacks=callbacks,
cache=cache,
)
return await _run_extractor(llm, community, input, level, args, callbacks)

View File

@ -93,7 +93,12 @@ async def _generate_report(
) -> CommunityReport | None:
"""Generate a report for a single community."""
return await runner(
community_id, community_context, community_level, callbacks, cache, strategy
community_id,
community_context,
community_level,
callbacks,
cache,
strategy,
)

View File

@ -7,7 +7,8 @@ from fnllm import ChatLLM
from graphrag.cache.pipeline_cache import PipelineCache
from graphrag.callbacks.workflow_callbacks import WorkflowCallbacks
from graphrag.index.llm.load_llm import load_llm, read_llm_params
from graphrag.config.models.language_model_config import LanguageModelConfig
from graphrag.index.llm.load_llm import load_llm
from graphrag.index.operations.summarize_descriptions.description_summary_extractor import (
SummarizeExtractor,
)
@ -25,9 +26,12 @@ async def run_graph_intelligence(
args: StrategyConfig,
) -> SummarizedDescriptionResult:
"""Run the graph intelligence entity extraction strategy."""
llm_config = read_llm_params(args.get("llm", {}))
llm_config = LanguageModelConfig(**args["llm"])
llm = load_llm(
"summarize_descriptions", llm_config, callbacks=callbacks, cache=cache
"summarize_descriptions",
llm_config,
callbacks=callbacks,
cache=cache,
)
return await run_summarize_descriptions(llm, id, descriptions, callbacks, args)

View File

@ -136,11 +136,7 @@ async def summarize_descriptions(
):
async with semaphore:
results = await strategy_exec(
id,
descriptions,
callbacks,
cache,
strategy_config,
id, descriptions, callbacks, cache, strategy_config
)
ticker(1)
return results

View File

@ -23,10 +23,11 @@ ItemType = TypeVar("ItemType")
class ParallelizationError(ValueError):
"""Exception for invalid parallel processing."""
def __init__(self, num_errors: int):
super().__init__(
f"{num_errors} Errors occurred while running parallel transformation, could not complete!"
)
def __init__(self, num_errors: int, example: str | None = None):
msg = f"{num_errors} Errors occurred while running parallel transformation, could not complete!"
if example:
msg += f"\nExample error: {example}"
super().__init__(msg)
async def derive_from_rows(
@ -153,6 +154,6 @@ async def _derive_from_rows_base(
callbacks.error("parallel transformation error", error, stack)
if len(errors) > 0:
raise ParallelizationError(len(errors))
raise ParallelizationError(len(errors), errors[0][1])
return result
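The raised error now includes the first underlying failure; roughly, a message looks like this (placeholder values, a sketch rather than real output):

# Sketch only, not part of this diff; values are placeholders.
err = ParallelizationError(3, "KeyError: 'text'")
print(err)
# 3 Errors occurred while running parallel transformation, could not complete!
# Example error: KeyError: 'text'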

View File

@ -55,16 +55,14 @@ async def run_workflows(
cache: PipelineCache | None = None,
callbacks: list[WorkflowCallbacks] | None = None,
logger: ProgressLogger | None = None,
run_id: str | None = None,
is_update_run: bool = False,
) -> AsyncIterable[PipelineRunResult]:
"""Run all workflows using a simplified pipeline."""
run_id = run_id or time.strftime("%Y%m%d-%H%M%S")
root_dir = config.root_dir or ""
root_dir = config.root_dir
progress_logger = logger or NullProgressLogger()
callbacks = callbacks or [ConsoleWorkflowCallbacks()]
callback_chain = create_callback_chain(callbacks, progress_logger)
storage_config = config.storage.model_dump() # type: ignore
storage_config = config.output.model_dump() # type: ignore
storage = StorageFactory().create_storage(
storage_type=storage_config["type"], # type: ignore
kwargs=storage_config,
@ -81,7 +79,7 @@ async def run_workflows(
if is_update_run:
progress_logger.info("Running incremental indexing.")
update_storage_config = config.update_index_storage.model_dump() # type: ignore
update_storage_config = config.update_index_output.model_dump() # type: ignore
update_index_storage = StorageFactory().create_storage(
storage_type=update_storage_config["type"], # type: ignore
kwargs=update_storage_config,

View File

@ -110,8 +110,11 @@ async def _run_entity_summarization(
pd.DataFrame
The updated entities dataframe with summarized descriptions.
"""
summarization_llm_settings = config.get_language_model_config(
config.summarize_descriptions.model_id
)
summarization_strategy = config.summarize_descriptions.resolved_strategy(
config.root_dir,
config.root_dir, summarization_llm_settings
)
# Prepare tasks for async summarization where needed
@ -120,7 +123,11 @@ async def _run_entity_summarization(
if isinstance(description, list) and len(description) > 1:
# Run entity summarization asynchronously
result = await run_entity_summarization(
row["title"], description, callbacks, cache, summarization_strategy
row["title"],
description,
callbacks,
cache,
summarization_strategy,
)
return result.description
# Handle case where description is a single-item list or not a list
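This model_id indirection is the pattern used throughout the reworked workflows. A hedged, self-contained sketch using the default model ids from the test fixtures (fake API key, illustrative model names):

from graphrag.config.create_graphrag_config import create_graphrag_config

config = create_graphrag_config({
    "models": {
        "default_chat_model": {
            "type": "openai_chat",
            "api_key": "NOT_A_REAL_KEY",
            "model": "gpt-4-turbo-preview",
        },
        "default_embedding_model": {
            "type": "openai_embedding",
            "api_key": "NOT_A_REAL_KEY",
            "model": "text-embedding-3-small",
        },
    }
})

# Per-step LLM settings are resolved via a model_id reference into config.models.
llm_settings = config.get_language_model_config(config.summarize_descriptions.model_id)
strategy = config.summarize_descriptions.resolved_strategy(config.root_dir, llm_settings)
num_threads = llm_settings.parallelization_num_threads
async_mode = llm_settings.async_mode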

View File

@ -10,8 +10,8 @@ import pandas as pd
from graphrag.cache.pipeline_cache import PipelineCache
from graphrag.callbacks.workflow_callbacks import WorkflowCallbacks
from graphrag.config.embeddings import get_embedded_fields, get_embedding_settings
from graphrag.config.models.graph_rag_config import GraphRagConfig
from graphrag.index.config.embeddings import get_embedded_fields, get_embedding_settings
from graphrag.index.flows.generate_text_embeddings import generate_text_embeddings
from graphrag.index.update.communities import (
_merge_and_resolve_nodes,
@ -150,7 +150,7 @@ async def update_dataframe_outputs(
# Generate text embeddings
progress_logger.info("Updating Text Embeddings")
embedded_fields = get_embedded_fields(config)
text_embed = get_embedding_settings(config.embeddings)
text_embed = get_embedding_settings(config)
await generate_text_embeddings(
final_documents=final_documents_df,
final_relationships=merged_relationships_df,

View File

@ -15,9 +15,11 @@ from graphrag.logger.print_progress import ProgressLogger
def validate_config_names(logger: ProgressLogger, parameters: GraphRagConfig) -> None:
"""Validate config file for LLM deployment name typos."""
# Validate Chat LLM configs
# TODO: Replace default_chat_model with a way to select the model
default_llm_settings = parameters.get_language_model_config("default_chat_model")
llm = load_llm(
"test-llm",
parameters.llm,
name="test-llm",
config=default_llm_settings,
callbacks=NoopWorkflowCallbacks(),
cache=None,
)
@ -29,9 +31,12 @@ def validate_config_names(logger: ProgressLogger, parameters: GraphRagConfig) ->
sys.exit(1)
# Validate Embeddings LLM configs
embedding_llm_settings = parameters.get_language_model_config(
parameters.embeddings.model_id
)
embed_llm = load_llm_embeddings(
"test-embed-llm",
parameters.embeddings.llm,
name="test-embed-llm",
llm_config=embedding_llm_settings,
callbacks=NoopWorkflowCallbacks(),
cache=None,
)

View File

@ -33,19 +33,25 @@ async def run_workflow(
claims = await load_table_from_storage(
"create_final_covariates", context.storage
)
async_mode = config.community_reports.async_mode
num_threads = config.community_reports.parallelization.num_threads
summarization_strategy = config.community_reports.resolved_strategy(config.root_dir)
community_reports_llm_settings = config.get_language_model_config(
config.community_reports.model_id
)
async_mode = community_reports_llm_settings.async_mode
num_threads = community_reports_llm_settings.parallelization_num_threads
summarization_strategy = config.community_reports.resolved_strategy(
config.root_dir, community_reports_llm_settings
)
output = await create_final_community_reports(
nodes,
edges,
entities,
communities,
claims,
callbacks,
context.cache,
summarization_strategy,
nodes_input=nodes,
edges_input=edges,
entities=entities,
communities=communities,
claims_input=claims,
callbacks=callbacks,
cache=context.cache,
summarization_strategy=summarization_strategy,
async_mode=async_mode,
num_threads=num_threads,
)

View File

@ -26,12 +26,15 @@ async def run_workflow(
"create_base_text_units", context.storage
)
claim_extraction_llm_settings = config.get_language_model_config(
config.claim_extraction.model_id
)
extraction_strategy = config.claim_extraction.resolved_strategy(
config.root_dir, config.encoding_model
config.root_dir, claim_extraction_llm_settings
)
async_mode = config.claim_extraction.async_mode
num_threads = config.claim_extraction.parallelization.num_threads
async_mode = claim_extraction_llm_settings.async_mode
num_threads = claim_extraction_llm_settings.parallelization_num_threads
output = await create_final_covariates(
text_units,

View File

@ -28,24 +28,28 @@ async def run_workflow(
"create_base_text_units", context.storage
)
extraction_strategy = config.entity_extraction.resolved_strategy(
config.root_dir, config.encoding_model
entity_extraction_llm_settings = config.get_language_model_config(
config.entity_extraction.model_id
)
extraction_num_threads = config.entity_extraction.parallelization.num_threads
extraction_async_mode = config.entity_extraction.async_mode
extraction_strategy = config.entity_extraction.resolved_strategy(
config.root_dir, entity_extraction_llm_settings
)
extraction_num_threads = entity_extraction_llm_settings.parallelization_num_threads
extraction_async_mode = entity_extraction_llm_settings.async_mode
entity_types = config.entity_extraction.entity_types
summarization_llm_settings = config.get_language_model_config(
config.summarize_descriptions.model_id
)
summarization_strategy = config.summarize_descriptions.resolved_strategy(
config.root_dir,
)
summarization_num_threads = (
config.summarize_descriptions.parallelization.num_threads
config.root_dir, summarization_llm_settings
)
summarization_num_threads = summarization_llm_settings.parallelization_num_threads
base_entity_nodes, base_relationship_edges = await extract_graph(
text_units,
callbacks,
context.cache,
text_units=text_units,
callbacks=callbacks,
cache=context.cache,
extraction_strategy=extraction_strategy,
extraction_num_threads=extraction_num_threads,
extraction_async_mode=extraction_async_mode,

View File

@ -6,8 +6,8 @@
import pandas as pd
from graphrag.callbacks.workflow_callbacks import WorkflowCallbacks
from graphrag.config.embeddings import get_embedded_fields, get_embedding_settings
from graphrag.config.models.graph_rag_config import GraphRagConfig
from graphrag.index.config.embeddings import get_embedded_fields, get_embedding_settings
from graphrag.index.context import PipelineRunContext
from graphrag.index.flows.generate_text_embeddings import (
generate_text_embeddings,
@ -40,7 +40,7 @@ async def run_workflow(
)
embedded_fields = get_embedded_fields(config)
text_embed = get_embedding_settings(config.embeddings)
text_embed = get_embedding_settings(config)
await generate_text_embeddings(
final_documents=final_documents,

View File

@ -16,3 +16,4 @@ MAX_TOKEN_COUNT = 2000
MIN_CHUNK_SIZE = 200
N_SUBSET_MAX = 300
MIN_CHUNK_OVERLAP = 0
PROMPT_TUNING_MODEL_ID = "default_chat_model"

View File

@ -6,12 +6,10 @@
import numpy as np
import pandas as pd
from fnllm import ChatLLM
from pydantic import TypeAdapter
import graphrag.config.defaults as defs
from graphrag.callbacks.noop_workflow_callbacks import NoopWorkflowCallbacks
from graphrag.config.models.graph_rag_config import GraphRagConfig
from graphrag.config.models.llm_parameters import LLMParameters
from graphrag.index.input.factory import create_input
from graphrag.index.llm.load_llm import load_llm_embeddings
from graphrag.index.operations.chunk_text.chunk_text import chunk_text
@ -60,8 +58,8 @@ async def load_docs_in_chunks(
k: int = K,
) -> list[str]:
"""Load docs into chunks for generating prompts."""
llm_config = TypeAdapter(LLMParameters).validate_python(
config.embeddings.resolved_strategy()["llm"]
embeddings_llm_settings = config.get_language_model_config(
config.embeddings.model_id
)
dataset = await create_input(config.input, logger, root)
@ -96,8 +94,8 @@ async def load_docs_in_chunks(
msg = "k must be an integer > 0"
raise ValueError(msg)
embedding_llm = load_llm_embeddings(
"prompt_tuning_embeddings",
llm_config,
name="prompt_tuning_embeddings",
llm_config=embeddings_llm_settings,
callbacks=NoopWorkflowCallbacks(),
cache=None,
)

View File

@ -45,9 +45,10 @@ def get_local_search_engine(
system_prompt: str | None = None,
) -> LocalSearch:
"""Create a local search engine based on data + configuration."""
default_llm_settings = config.get_language_model_config("default_chat_model")
llm = get_llm(config)
text_embedder = get_text_embedder(config)
token_encoder = tiktoken.get_encoding(config.encoding_model)
token_encoder = tiktoken.get_encoding(default_llm_settings.encoding_model)
ls_config = config.local_search
@ -102,7 +103,11 @@ def get_global_search_engine(
general_knowledge_inclusion_prompt: str | None = None,
) -> GlobalSearch:
"""Create a global search engine based on data + configuration."""
token_encoder = tiktoken.get_encoding(config.encoding_model)
# TODO: Global search should select model based on config??
default_llm_settings = config.get_language_model_config("default_chat_model")
# Here we get encoding based on specified encoding name
token_encoder = tiktoken.get_encoding(default_llm_settings.encoding_model)
gs_config = config.global_search
dynamic_community_selection_kwargs = {}
@ -111,7 +116,8 @@ def get_global_search_engine(
dynamic_community_selection_kwargs.update({
"llm": get_llm(config),
"token_encoder": tiktoken.encoding_for_model(config.llm.model),
# And here we get encoding based on model
"token_encoder": tiktoken.encoding_for_model(default_llm_settings.model),
"keep_parent": gs_config.dynamic_search_keep_parent,
"num_repeats": gs_config.dynamic_search_num_repeats,
"use_summary": gs_config.dynamic_search_use_summary,
@ -178,9 +184,10 @@ def get_drift_search_engine(
reduce_system_prompt: str | None = None,
) -> DRIFTSearch:
"""Create a local search engine based on data + configuration."""
default_llm_settings = config.get_language_model_config("default_chat_model")
llm = get_llm(config)
text_embedder = get_text_embedder(config)
token_encoder = tiktoken.get_encoding(config.encoding_model)
token_encoder = tiktoken.get_encoding(default_llm_settings.encoding_model)
return DRIFTSearch(
llm=llm,
@ -208,9 +215,10 @@ def get_basic_search_engine(
system_prompt: str | None = None,
) -> BasicSearch:
"""Create a basic search engine based on data + configuration."""
default_llm_settings = config.get_language_model_config("default_chat_model")
llm = get_llm(config)
text_embedder = get_text_embedder(config)
token_encoder = tiktoken.get_encoding(config.encoding_model)
token_encoder = tiktoken.get_encoding(default_llm_settings.encoding_model)
ls_config = config.basic_search
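Because encoding_model now lives on each model entry rather than at the top level of the config, the tiktoken lookup is a per-model concern. A short sketch (same fake-key config construction as the earlier sketch; the default encoding is typically "cl100k_base"):

import tiktoken

from graphrag.config.create_graphrag_config import create_graphrag_config

config = create_graphrag_config({
    "models": {
        "default_chat_model": {"type": "openai_chat", "api_key": "NOT_A_REAL_KEY", "model": "gpt-4-turbo-preview"},
        "default_embedding_model": {"type": "openai_embedding", "api_key": "NOT_A_REAL_KEY", "model": "text-embedding-3-small"},
    }
})
settings = config.get_language_model_config("default_chat_model")
token_encoder = tiktoken.get_encoding(settings.encoding_model)
# Dynamic community selection derives the encoding from the model name instead:
# tiktoken.encoding_for_model(settings.model)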

View File

@ -14,61 +14,65 @@ from graphrag.query.llm.oai.typing import OpenaiApiType
def get_llm(config: GraphRagConfig) -> ChatOpenAI:
"""Get the LLM client."""
is_azure_client = config.llm.type == LLMType.AzureOpenAIChat
debug_llm_key = config.llm.api_key or ""
default_llm_settings = config.get_language_model_config("default_chat_model")
is_azure_client = default_llm_settings.type == LLMType.AzureOpenAIChat
debug_llm_key = default_llm_settings.api_key or ""
llm_debug_info = {
**config.llm.model_dump(),
**default_llm_settings.model_dump(),
"api_key": f"REDACTED,len={len(debug_llm_key)}",
}
audience = (
config.llm.audience
if config.llm.audience
default_llm_settings.audience
if default_llm_settings.audience
else "https://cognitiveservices.azure.com/.default"
)
print(f"creating llm client with {llm_debug_info}") # noqa T201
return ChatOpenAI(
api_key=config.llm.api_key,
api_key=default_llm_settings.api_key,
azure_ad_token_provider=(
get_bearer_token_provider(DefaultAzureCredential(), audience)
if is_azure_client and not config.llm.api_key
if is_azure_client and not default_llm_settings.api_key
else None
),
api_base=config.llm.api_base,
organization=config.llm.organization,
model=config.llm.model,
api_base=default_llm_settings.api_base,
organization=default_llm_settings.organization,
model=default_llm_settings.model,
api_type=OpenaiApiType.AzureOpenAI if is_azure_client else OpenaiApiType.OpenAI,
deployment_name=config.llm.deployment_name,
api_version=config.llm.api_version,
max_retries=config.llm.max_retries,
request_timeout=config.llm.request_timeout,
deployment_name=default_llm_settings.deployment_name,
api_version=default_llm_settings.api_version,
max_retries=default_llm_settings.max_retries,
request_timeout=default_llm_settings.request_timeout,
)
def get_text_embedder(config: GraphRagConfig) -> OpenAIEmbedding:
"""Get the LLM client for embeddings."""
is_azure_client = config.embeddings.llm.type == LLMType.AzureOpenAIEmbedding
debug_embedding_api_key = config.embeddings.llm.api_key or ""
embeddings_llm_settings = config.get_language_model_config(
config.embeddings.model_id
)
is_azure_client = embeddings_llm_settings.type == LLMType.AzureOpenAIEmbedding
debug_embedding_api_key = embeddings_llm_settings.api_key or ""
llm_debug_info = {
**config.embeddings.llm.model_dump(),
**embeddings_llm_settings.model_dump(),
"api_key": f"REDACTED,len={len(debug_embedding_api_key)}",
}
if config.embeddings.llm.audience is None:
if embeddings_llm_settings.audience is None:
audience = "https://cognitiveservices.azure.com/.default"
else:
audience = config.embeddings.llm.audience
audience = embeddings_llm_settings.audience
print(f"creating embedding llm client with {llm_debug_info}") # noqa T201
return OpenAIEmbedding(
api_key=config.embeddings.llm.api_key,
api_key=embeddings_llm_settings.api_key,
azure_ad_token_provider=(
get_bearer_token_provider(DefaultAzureCredential(), audience)
if is_azure_client and not config.embeddings.llm.api_key
if is_azure_client and not embeddings_llm_settings.api_key
else None
),
api_base=config.embeddings.llm.api_base,
organization=config.llm.organization,
api_base=embeddings_llm_settings.api_base,
organization=embeddings_llm_settings.organization,
api_type=OpenaiApiType.AzureOpenAI if is_azure_client else OpenaiApiType.OpenAI,
model=config.embeddings.llm.model,
deployment_name=config.embeddings.llm.deployment_name,
api_version=config.embeddings.llm.api_version,
max_retries=config.embeddings.llm.max_retries,
model=embeddings_llm_settings.model,
deployment_name=embeddings_llm_settings.deployment_name,
api_version=embeddings_llm_settings.api_version,
max_retries=embeddings_llm_settings.max_retries,
)

View File

@ -11,7 +11,6 @@ from typing import Any
import tiktoken
from tqdm.asyncio import tqdm_asyncio
from graphrag.config.models.drift_search_config import DRIFTSearchConfig
from graphrag.query.context_builder.conversation_history import ConversationHistory
from graphrag.query.context_builder.entity_extraction import EntityVectorStoreKey
from graphrag.query.llm.oai.chat_openai import ChatOpenAI
@ -35,7 +34,6 @@ class DRIFTSearch(BaseSearch[DRIFTSearchContextBuilder]):
self,
llm: ChatOpenAI,
context_builder: DRIFTSearchContextBuilder,
config: DRIFTSearchConfig | None = None,
token_encoder: tiktoken.Encoding | None = None,
query_state: QueryState | None = None,
):
@ -51,12 +49,13 @@ class DRIFTSearch(BaseSearch[DRIFTSearchContextBuilder]):
"""
super().__init__(llm, context_builder, token_encoder)
self.config = config or DRIFTSearchConfig()
self.context_builder = context_builder
self.token_encoder = token_encoder
self.query_state = query_state or QueryState()
self.primer = DRIFTPrimer(
config=self.config, chat_llm=llm, token_encoder=token_encoder
config=self.context_builder.config,
chat_llm=llm,
token_encoder=token_encoder,
)
self.local_search = self.init_local_search()
@ -69,21 +68,21 @@ class DRIFTSearch(BaseSearch[DRIFTSearchContextBuilder]):
LocalSearch: An instance of the LocalSearch class with the configured parameters.
"""
local_context_params = {
"text_unit_prop": self.config.local_search_text_unit_prop,
"community_prop": self.config.local_search_community_prop,
"top_k_mapped_entities": self.config.local_search_top_k_mapped_entities,
"top_k_relationships": self.config.local_search_top_k_relationships,
"text_unit_prop": self.context_builder.config.local_search_text_unit_prop,
"community_prop": self.context_builder.config.local_search_community_prop,
"top_k_mapped_entities": self.context_builder.config.local_search_top_k_mapped_entities,
"top_k_relationships": self.context_builder.config.local_search_top_k_relationships,
"include_entity_rank": True,
"include_relationship_weight": True,
"include_community_rank": False,
"return_candidate_context": False,
"embedding_vectorstore_key": EntityVectorStoreKey.ID,
"max_tokens": self.config.local_search_max_data_tokens,
"max_tokens": self.context_builder.config.local_search_max_data_tokens,
}
llm_params = {
"max_tokens": self.config.local_search_llm_max_gen_tokens,
"temperature": self.config.local_search_temperature,
"max_tokens": self.context_builder.config.local_search_llm_max_gen_tokens,
"temperature": self.context_builder.config.local_search_temperature,
"response_format": {"type": "json_object"},
}
@ -220,13 +219,15 @@ class DRIFTSearch(BaseSearch[DRIFTSearchContextBuilder]):
# Main loop
epochs = 0
llm_call_offset = 0
while epochs < self.config.n:
while epochs < self.context_builder.config.n:
actions = self.query_state.rank_incomplete_actions()
if len(actions) == 0:
log.info("No more actions to take. Exiting DRIFT loop.")
break
actions = actions[: self.config.drift_k_followups]
llm_call_offset += len(actions) - self.config.drift_k_followups
actions = actions[: self.context_builder.config.drift_k_followups]
llm_call_offset += (
len(actions) - self.context_builder.config.drift_k_followups
)
# Process actions
results = await self.asearch_step(
global_query=query, search_engine=self.local_search, actions=actions
@ -260,8 +261,8 @@ class DRIFTSearch(BaseSearch[DRIFTSearchContextBuilder]):
llm_calls=llm_calls,
prompt_tokens=prompt_tokens,
output_tokens=output_tokens,
max_tokens=self.config.reduce_max_tokens,
temperature=self.config.reduce_temperature,
max_tokens=self.context_builder.config.reduce_max_tokens,
temperature=self.context_builder.config.reduce_temperature,
)
return SearchResult(
@ -317,8 +318,8 @@ class DRIFTSearch(BaseSearch[DRIFTSearchContextBuilder]):
async for resp in self._reduce_response_streaming(
responses=result.response,
query=query,
max_tokens=self.config.reduce_max_tokens,
temperature=self.config.reduce_temperature,
max_tokens=self.context_builder.config.reduce_max_tokens,
temperature=self.context_builder.config.reduce_temperature,
):
yield resp

View File

@ -7,7 +7,7 @@ from __future__ import annotations
from typing import TYPE_CHECKING, ClassVar
from graphrag.config.enums import StorageType
from graphrag.config.enums import OutputType
from graphrag.storage.blob_pipeline_storage import create_blob_storage
from graphrag.storage.cosmosdb_pipeline_storage import create_cosmosdb_storage
from graphrag.storage.file_pipeline_storage import create_file_storage
@ -35,17 +35,17 @@ class StorageFactory:
@classmethod
def create_storage(
cls, storage_type: StorageType | str, kwargs: dict
cls, storage_type: OutputType | str, kwargs: dict
) -> PipelineStorage:
"""Create or get a storage object from the provided type."""
match storage_type:
case StorageType.blob:
case OutputType.blob:
return create_blob_storage(**kwargs)
case StorageType.cosmosdb:
case OutputType.cosmosdb:
return create_cosmosdb_storage(**kwargs)
case StorageType.file:
case OutputType.file:
return create_file_storage(**kwargs)
case StorageType.memory:
case OutputType.memory:
return MemoryPipelineStorage()
case _:
if storage_type in cls.storage_types:
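Usage matches the updated storage factory tests later in this diff; a quick sketch:

from graphrag.config.enums import OutputType
from graphrag.storage.factory import StorageFactory

# kwargs carry the output settings, including the storage type itself.
file_storage = StorageFactory.create_storage(
    OutputType.file, {"type": "file", "base_dir": "/tmp/teststorage"}
)
memory_storage = StorageFactory.create_storage(OutputType.memory, {"type": "memory"})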

View File

@ -1,25 +0,0 @@
# Copyright (c) 2024 Microsoft Corporation.
# Licensed under the MIT License
"""Utilities for working with embeddings stores."""
from graphrag.index.config.embeddings import all_embeddings
def create_collection_name(
container_name: str, embedding_name: str, validate: bool = True
) -> str:
"""
Create a collection name for the embedding store.
Within any given vector store, we can have multiple sets of embeddings organized into projects.
The `container` param is used for this partitioning, and is added as a prefix to the collection name for differentiation.
The embedding name is fixed, with the available list defined in graphrag.index.config.embeddings
Note that we use dot notation in our names, but many vector stores do not support this - so we convert to dashes.
"""
if validate and embedding_name not in all_embeddings:
msg = f"Invalid embedding name: {embedding_name}"
raise KeyError(msg)
return f"{container_name}-{embedding_name}".replace(".", "-")

View File

@ -7,7 +7,6 @@
1,
2500
],
"subworkflows": 1,
"max_runtime": 150,
"expected_artifacts": 0
},
@ -16,7 +15,6 @@
1,
2500
],
"subworkflows": 1,
"max_runtime": 150,
"expected_artifacts": 0
},
@ -25,7 +23,6 @@
1,
2500
],
"subworkflows": 1,
"max_runtime": 300,
"expected_artifacts": 0
},
@ -38,7 +35,6 @@
"type",
"description"
],
"subworkflows": 1,
"max_runtime": 300,
"expected_artifacts": 1
},
@ -47,7 +43,6 @@
1,
6000
],
"subworkflows": 1,
"max_runtime": 150,
"expected_artifacts": 1
},
@ -61,7 +56,6 @@
"x",
"y"
],
"subworkflows": 1,
"max_runtime": 150,
"expected_artifacts": 1
},
@ -70,7 +64,6 @@
1,
2500
],
"subworkflows": 1,
"max_runtime": 150,
"expected_artifacts": 1
},
@ -90,7 +83,6 @@
"period",
"size"
],
"subworkflows": 1,
"max_runtime": 300,
"expected_artifacts": 1
},
@ -103,7 +95,6 @@
"relationship_ids",
"entity_ids"
],
"subworkflows": 1,
"max_runtime": 150,
"expected_artifacts": 1
},
@ -112,7 +103,6 @@
1,
2500
],
"subworkflows": 1,
"max_runtime": 150,
"expected_artifacts": 1
},
@ -121,7 +111,6 @@
1,
2500
],
"subworkflows": 1,
"max_runtime": 150,
"expected_artifacts": 1
}

View File

@ -1,24 +1,50 @@
models:
default_chat_model:
type: ${GRAPHRAG_LLM_TYPE}
api_key: ${GRAPHRAG_API_KEY}
api_base: ${GRAPHRAG_API_BASE}
api_version: ${GRAPHRAG_API_VERSION}
deployment_name: ${GRAPHRAG_LLM_DEPLOYMENT_NAME}
model: ${GRAPHRAG_LLM_MODEL}
tokens_per_minute: ${GRAPHRAG_LLM_TPM}
requests_per_minute: ${GRAPHRAG_LLM_RPM}
model_supports_json: true
parallelization_num_threads: 50
parallelization_stagger: 0.3
async_mode: threaded
default_embedding_model:
type: ${GRAPHRAG_EMBEDDING_TYPE}
api_key: ${GRAPHRAG_API_KEY}
api_base: ${GRAPHRAG_API_BASE}
api_version: ${GRAPHRAG_API_VERSION}
deployment_name: ${GRAPHRAG_EMBEDDING_DEPLOYMENT_NAME}
model: ${GRAPHRAG_EMBEDDING_MODEL}
tokens_per_minute: ${GRAPHRAG_EMBEDDING_TPM}
requests_per_minute: ${GRAPHRAG_EMBEDDING_RPM}
parallelization_num_threads: 50
parallelization_stagger: 0.3
async_mode: threaded
vector_store:
type: "lancedb"
db_uri: "./tests/fixtures/min-csv/lancedb"
container_name: "lancedb_ci"
overwrite: True
input:
file_type: csv
file_pattern: ".*\\.csv$$"
embeddings:
vector_store:
type: "lancedb"
db_uri: "./tests/fixtures/min-csv/lancedb"
container_name: "lancedb_ci"
overwrite: True
model_id: "default_embedding_model"
storage:
type: file # or blob
base_dir: "output/${timestamp}/artifacts"
# connection_string: <azure_blob_storage_connection_string>
# container_name: <azure_blob_storage_container_name>
type: file
base_dir: "output"
reporting:
type: file # or console, blob
base_dir: "output/${timestamp}/reports"
# connection_string: <azure_blob_storage_connection_string>
# container_name: <azure_blob_storage_container_name>
type: file
base_dir: "logs"
snapshots:
embeddings: True

View File

@ -7,7 +7,6 @@
1,
2500
],
"subworkflows": 1,
"max_runtime": 150,
"expected_artifacts": 0
},
@ -16,7 +15,6 @@
1,
2500
],
"subworkflows": 1,
"max_runtime": 150,
"expected_artifacts": 0
},
@ -25,7 +23,6 @@
1,
2500
],
"subworkflows": 1,
"max_runtime": 300,
"expected_artifacts": 0
},
@ -43,7 +40,6 @@
"end_date",
"source_text"
],
"subworkflows": 1,
"max_runtime": 300,
"expected_artifacts": 1
},
@ -56,7 +52,6 @@
"type",
"description"
],
"subworkflows": 1,
"max_runtime": 300,
"expected_artifacts": 1
},
@ -65,7 +60,6 @@
1,
6000
],
"subworkflows": 1,
"max_runtime": 150,
"expected_artifacts": 1
},
@ -79,7 +73,6 @@
"x",
"y"
],
"subworkflows": 1,
"max_runtime": 150,
"expected_artifacts": 1
},
@ -88,7 +81,6 @@
1,
2500
],
"subworkflows": 1,
"max_runtime": 150,
"expected_artifacts": 1
},
@ -108,7 +100,6 @@
"period",
"size"
],
"subworkflows": 1,
"max_runtime": 300,
"expected_artifacts": 1
},
@ -121,7 +112,6 @@
"relationship_ids",
"entity_ids"
],
"subworkflows": 1,
"max_runtime": 150,
"expected_artifacts": 1
},
@ -130,7 +120,6 @@
1,
2500
],
"subworkflows": 1,
"max_runtime": 150,
"expected_artifacts": 1
},
@ -139,7 +128,6 @@
1,
2500
],
"subworkflows": 1,
"max_runtime": 150,
"expected_artifacts": 1
}

View File

@ -1,12 +1,41 @@
models:
default_chat_model:
type: ${GRAPHRAG_LLM_TYPE}
api_key: ${GRAPHRAG_API_KEY}
api_base: ${GRAPHRAG_API_BASE}
api_version: ${GRAPHRAG_API_VERSION}
deployment_name: ${GRAPHRAG_LLM_DEPLOYMENT_NAME}
model: ${GRAPHRAG_LLM_MODEL}
tokens_per_minute: ${GRAPHRAG_LLM_TPM}
requests_per_minute: ${GRAPHRAG_LLM_RPM}
model_supports_json: true
parallelization_num_threads: 50
parallelization_stagger: 0.3
async_mode: threaded
default_embedding_model:
type: ${GRAPHRAG_EMBEDDING_TYPE}
api_key: ${GRAPHRAG_API_KEY}
api_base: ${GRAPHRAG_API_BASE}
api_version: ${GRAPHRAG_API_VERSION}
deployment_name: ${GRAPHRAG_EMBEDDING_DEPLOYMENT_NAME}
model: ${GRAPHRAG_EMBEDDING_MODEL}
tokens_per_minute: ${GRAPHRAG_EMBEDDING_TPM}
requests_per_minute: ${GRAPHRAG_EMBEDDING_RPM}
parallelization_num_threads: 50
parallelization_stagger: 0.3
async_mode: threaded
vector_store:
type: "azure_ai_search"
url: ${AZURE_AI_SEARCH_URL_ENDPOINT}
api_key: ${AZURE_AI_SEARCH_API_KEY}
container_name: "simple_text_ci"
claim_extraction:
enabled: true
embeddings:
vector_store:
type: "azure_ai_search"
url: ${AZURE_AI_SEARCH_URL_ENDPOINT}
api_key: ${AZURE_AI_SEARCH_API_KEY}
container_name: "simple_text_ci"
model_id: "default_embedding_model"
community_reports:
prompt: "prompts/community_report.txt"
@ -15,16 +44,12 @@ community_reports:
storage:
type: file # or blob
base_dir: "output/${timestamp}/artifacts"
# connection_string: <azure_blob_storage_connection_string>
# container_name: <azure_blob_storage_container_name>
type: file
base_dir: "output"
reporting:
type: file # or console, blob
base_dir: "output/${timestamp}/reports"
# connection_string: <azure_blob_storage_connection_string>
# container_name: <azure_blob_storage_container_name>
type: file
base_dir: "logs"
snapshots:
embeddings: True

View File

@ -9,7 +9,7 @@ import sys
import pytest
from graphrag.config.enums import StorageType
from graphrag.config.enums import OutputType
from graphrag.storage.blob_pipeline_storage import BlobPipelineStorage
from graphrag.storage.cosmosdb_pipeline_storage import CosmosDBPipelineStorage
from graphrag.storage.factory import StorageFactory
@ -29,7 +29,7 @@ def test_create_blob_storage():
"base_dir": "testbasedir",
"container_name": "testcontainer",
}
storage = StorageFactory.create_storage(StorageType.blob, kwargs)
storage = StorageFactory.create_storage(OutputType.blob, kwargs)
assert isinstance(storage, BlobPipelineStorage)
@ -44,19 +44,19 @@ def test_create_cosmosdb_storage():
"base_dir": "testdatabase",
"container_name": "testcontainer",
}
storage = StorageFactory.create_storage(StorageType.cosmosdb, kwargs)
storage = StorageFactory.create_storage(OutputType.cosmosdb, kwargs)
assert isinstance(storage, CosmosDBPipelineStorage)
def test_create_file_storage():
kwargs = {"type": "file", "base_dir": "/tmp/teststorage"}
storage = StorageFactory.create_storage(StorageType.file, kwargs)
storage = StorageFactory.create_storage(OutputType.file, kwargs)
assert isinstance(storage, FilePipelineStorage)
def test_create_memory_storage():
kwargs = {"type": "memory"}
storage = StorageFactory.create_storage(StorageType.memory, kwargs)
storage = StorageFactory.create_storage(OutputType.memory, kwargs)
assert isinstance(storage, MemoryPipelineStorage)

View File

@ -152,24 +152,12 @@ class TestIndexer:
def __assert_indexer_outputs(
self, root: Path, workflow_config: dict[str, dict[str, Any]]
):
outputs_path = root / "output"
output_entries = list(outputs_path.iterdir())
# Sort the output folders by creation time, most recent
output_entries.sort(key=lambda entry: entry.stat().st_ctime, reverse=True)
output_path = root / "output"
if not debug:
assert len(output_entries) == 1, (
f"Expected one output folder, found {len(output_entries)}"
)
output_path = output_entries[0]
assert output_path.exists(), "output folder does not exist"
artifacts = output_path / "artifacts"
assert artifacts.exists(), "artifact folder does not exist"
# Check stats for all workflow
stats = json.loads((artifacts / "stats.json").read_bytes().decode("utf-8"))
stats = json.loads((output_path / "stats.json").read_bytes().decode("utf-8"))
# Check all workflows run
expected_artifacts = 0
@ -193,7 +181,7 @@ class TestIndexer:
)
# Check artifacts
artifact_files = os.listdir(artifacts)
artifact_files = os.listdir(output_path)
# check that the number of workflows matches the number of artifacts
assert len(artifact_files) == (expected_artifacts + 3), (
@ -202,7 +190,7 @@ class TestIndexer:
for artifact in artifact_files:
if artifact.endswith(".parquet"):
output_df = pd.read_parquet(artifacts / artifact)
output_df = pd.read_parquet(output_path / artifact)
artifact_name = artifact.split(".")[0]
try:

View File

@ -0,0 +1,9 @@
models:
default_chat_model:
api_key: ${CUSTOM_API_KEY}
type: openai_chat
model: gpt-4-turbo-preview
default_embedding_model:
api_key: ${CUSTOM_API_KEY}
type: openai_embedding
model: text-embedding-3-small
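A sketch of loading a fixture like this one (the referenced env var must be present, otherwise load_config raises a KeyError, as the new tests below assert; the paths and file name are illustrative since the diff does not show them):

import os
from pathlib import Path

from graphrag.config.load_config import load_config

os.environ["CUSTOM_API_KEY"] = "NOT_A_REAL_KEY"  # referenced by the YAML above
root_dir = Path("tests/unit/config/fixtures/minimal_config")  # folder holding the YAML (settings.yaml by convention)
config = load_config(root_dir=root_dir)
# CLI-style overrides use dotted paths into the config:
config = load_config(root_dir=root_dir, cli_overrides={"output.base_dir": "some_output_dir"})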

View File

@ -0,0 +1,9 @@
models:
default_chat_model:
api_key: ${SOME_NON_EXISTENT_ENV_VAR}
type: openai_chat
model: gpt-4-turbo-preview
default_embedding_model:
api_key: ${SOME_NON_EXISTENT_ENV_VAR}
type: openai_embedding
model: text-embedding-3-small

View File

@ -0,0 +1,168 @@
# Copyright (c) 2024 Microsoft Corporation.
# Licensed under the MIT License
import os
from pathlib import Path
from unittest import mock
import pytest
from pydantic import ValidationError
import graphrag.config.defaults as defs
from graphrag.config.create_graphrag_config import create_graphrag_config
from graphrag.config.enums import AzureAuthType, LLMType
from graphrag.config.load_config import load_config
from tests.unit.config.utils import (
DEFAULT_EMBEDDING_MODEL_CONFIG,
DEFAULT_MODEL_CONFIG,
FAKE_API_KEY,
assert_graphrag_configs,
get_default_graphrag_config,
)
def test_missing_openai_required_api_key() -> None:
model_config_missing_api_key = {
defs.DEFAULT_CHAT_MODEL_ID: {
"type": LLMType.OpenAIChat,
"model": defs.LLM_MODEL,
},
defs.DEFAULT_EMBEDDING_MODEL_ID: DEFAULT_EMBEDDING_MODEL_CONFIG,
}
# API Key required for OpenAIChat
with pytest.raises(ValidationError):
create_graphrag_config({"models": model_config_missing_api_key})
# API Key required for OpenAIEmbedding
model_config_missing_api_key[defs.DEFAULT_CHAT_MODEL_ID]["type"] = (
LLMType.OpenAIEmbedding
)
with pytest.raises(ValidationError):
create_graphrag_config({"models": model_config_missing_api_key})
def test_missing_azure_api_key() -> None:
model_config_missing_api_key = {
defs.DEFAULT_CHAT_MODEL_ID: {
"type": LLMType.AzureOpenAIChat,
"azure_auth_type": AzureAuthType.APIKey,
"model": defs.LLM_MODEL,
"api_base": "some_api_base",
"api_version": "some_api_version",
"deployment_name": "some_deployment_name",
},
defs.DEFAULT_EMBEDDING_MODEL_ID: DEFAULT_EMBEDDING_MODEL_CONFIG,
}
with pytest.raises(ValidationError):
create_graphrag_config({"models": model_config_missing_api_key})
# API Key not required for managed identity
model_config_missing_api_key[defs.DEFAULT_CHAT_MODEL_ID]["azure_auth_type"] = (
AzureAuthType.ManagedIdentity
)
create_graphrag_config({"models": model_config_missing_api_key})
def test_conflicting_azure_api_key() -> None:
model_config_conflicting_api_key = {
defs.DEFAULT_CHAT_MODEL_ID: {
"type": LLMType.AzureOpenAIChat,
"azure_auth_type": AzureAuthType.ManagedIdentity,
"model": defs.LLM_MODEL,
"api_base": "some_api_base",
"api_version": "some_api_version",
"deployment_name": "some_deployment_name",
"api_key": "THIS_SHOULD_NOT_BE_SET_WHEN_USING_MANAGED_IDENTITY",
},
defs.DEFAULT_EMBEDDING_MODEL_ID: DEFAULT_EMBEDDING_MODEL_CONFIG,
}
with pytest.raises(ValidationError):
create_graphrag_config({"models": model_config_conflicting_api_key})
base_azure_model_config = {
"type": LLMType.AzureOpenAIChat,
"azure_auth_type": AzureAuthType.ManagedIdentity,
"model": defs.LLM_MODEL,
"api_base": "some_api_base",
"api_version": "some_api_version",
"deployment_name": "some_deployment_name",
}
def test_missing_azure_api_base() -> None:
missing_api_base_config = base_azure_model_config.copy()
del missing_api_base_config["api_base"]
with pytest.raises(ValidationError):
create_graphrag_config({
"models": {
defs.DEFAULT_CHAT_MODEL_ID: missing_api_base_config,
defs.DEFAULT_EMBEDDING_MODEL_ID: DEFAULT_EMBEDDING_MODEL_CONFIG,
}
})
def test_missing_azure_api_version() -> None:
missing_api_version_config = base_azure_model_config.copy()
del missing_api_version_config["api_version"]
with pytest.raises(ValidationError):
create_graphrag_config({
"models": {
defs.DEFAULT_CHAT_MODEL_ID: missing_api_version_config,
defs.DEFAULT_EMBEDDING_MODEL_ID: DEFAULT_EMBEDDING_MODEL_CONFIG,
}
})
def test_missing_azure_deployment_name() -> None:
missing_deployment_name_config = base_azure_model_config.copy()
del missing_deployment_name_config["deployment_name"]
with pytest.raises(ValidationError):
create_graphrag_config({
"models": {
defs.DEFAULT_CHAT_MODEL_ID: missing_deployment_name_config,
defs.DEFAULT_EMBEDDING_MODEL_ID: DEFAULT_EMBEDDING_MODEL_CONFIG,
}
})
def test_default_config() -> None:
expected = get_default_graphrag_config()
actual = create_graphrag_config({"models": DEFAULT_MODEL_CONFIG})
assert_graphrag_configs(actual, expected)
@mock.patch.dict(os.environ, {"CUSTOM_API_KEY": FAKE_API_KEY}, clear=True)
def test_load_minimal_config() -> None:
cwd = Path(__file__).parent
root_dir = (cwd / "fixtures" / "minimal_config").resolve()
expected = get_default_graphrag_config(str(root_dir))
actual = load_config(root_dir=root_dir)
assert_graphrag_configs(actual, expected)
@mock.patch.dict(os.environ, {"CUSTOM_API_KEY": FAKE_API_KEY}, clear=True)
def test_load_config_with_cli_overrides() -> None:
cwd = Path(__file__).parent
root_dir = (cwd / "fixtures" / "minimal_config").resolve()
output_dir = "some_output_dir"
expected_output_base_dir = root_dir / output_dir
expected = get_default_graphrag_config(str(root_dir))
expected.output.base_dir = str(expected_output_base_dir)
actual = load_config(
root_dir=root_dir, cli_overrides={"output.base_dir": output_dir}
)
assert_graphrag_configs(actual, expected)
def test_load_config_missing_env_vars() -> None:
cwd = Path(__file__).parent
root_dir = (cwd / "fixtures" / "minimal_config_missing_env_var").resolve()
with pytest.raises(KeyError):
load_config(root_dir=root_dir)

View File

@ -1,556 +0,0 @@
# Copyright (c) 2024 Microsoft Corporation.
# Licensed under the MIT License
import json
import os
import re
import unittest
from pathlib import Path
from unittest import mock
import pytest
import yaml
import graphrag.config.defaults as defs
from graphrag.config.create_graphrag_config import create_graphrag_config
from graphrag.config.errors import (
ApiKeyMissingError,
AzureApiBaseMissingError,
AzureDeploymentNameMissingError,
)
from graphrag.config.models.basic_search_config import BasicSearchConfig
from graphrag.config.models.cache_config import CacheConfig
from graphrag.config.models.chunking_config import ChunkingConfig
from graphrag.config.models.claim_extraction_config import ClaimExtractionConfig
from graphrag.config.models.cluster_graph_config import ClusterGraphConfig
from graphrag.config.models.community_reports_config import CommunityReportsConfig
from graphrag.config.models.drift_search_config import DRIFTSearchConfig
from graphrag.config.models.embed_graph_config import EmbedGraphConfig
from graphrag.config.models.entity_extraction_config import EntityExtractionConfig
from graphrag.config.models.global_search_config import GlobalSearchConfig
from graphrag.config.models.graph_rag_config import GraphRagConfig
from graphrag.config.models.input_config import InputConfig
from graphrag.config.models.llm_parameters import LLMParameters
from graphrag.config.models.local_search_config import LocalSearchConfig
from graphrag.config.models.parallelization_parameters import ParallelizationParameters
from graphrag.config.models.reporting_config import ReportingConfig
from graphrag.config.models.snapshots_config import SnapshotsConfig
from graphrag.config.models.storage_config import StorageConfig
from graphrag.config.models.summarize_descriptions_config import (
SummarizeDescriptionsConfig,
)
from graphrag.config.models.text_embedding_config import TextEmbeddingConfig
from graphrag.config.models.umap_config import UmapConfig
from graphrag.index.config.cache import PipelineFileCacheConfig
from graphrag.index.config.input import (
PipelineInputConfig,
)
from graphrag.index.config.reporting import PipelineFileReportingConfig
from graphrag.index.config.storage import PipelineFileStorageConfig
from graphrag.index.create_pipeline_config import create_pipeline_config
current_dir = os.path.dirname(__file__)
ALL_ENV_VARS = {
"GRAPHRAG_API_BASE": "http://some/base",
"GRAPHRAG_API_KEY": "test",
"GRAPHRAG_API_ORGANIZATION": "test_org",
"GRAPHRAG_API_PROXY": "http://some/proxy",
"GRAPHRAG_API_VERSION": "v1234",
"GRAPHRAG_ASYNC_MODE": "asyncio",
"GRAPHRAG_CACHE_STORAGE_ACCOUNT_BLOB_URL": "cache_account_blob_url",
"GRAPHRAG_CACHE_BASE_DIR": "/some/cache/dir",
"GRAPHRAG_CACHE_CONNECTION_STRING": "test_cs1",
"GRAPHRAG_CACHE_CONTAINER_NAME": "test_cn1",
"GRAPHRAG_CACHE_TYPE": "blob",
"GRAPHRAG_CHUNK_BY_COLUMNS": "a,b",
"GRAPHRAG_CHUNK_OVERLAP": "12",
"GRAPHRAG_CHUNK_SIZE": "500",
"GRAPHRAG_CHUNK_ENCODING_MODEL": "encoding-c",
"GRAPHRAG_CLAIM_EXTRACTION_ENABLED": "True",
"GRAPHRAG_CLAIM_EXTRACTION_DESCRIPTION": "test 123",
"GRAPHRAG_CLAIM_EXTRACTION_MAX_GLEANINGS": "5000",
"GRAPHRAG_CLAIM_EXTRACTION_PROMPT_FILE": "tests/unit/config/prompt-a.txt",
"GRAPHRAG_CLAIM_EXTRACTION_ENCODING_MODEL": "encoding_a",
"GRAPHRAG_COMMUNITY_REPORTS_MAX_LENGTH": "23456",
"GRAPHRAG_COMMUNITY_REPORTS_PROMPT_FILE": "tests/unit/config/prompt-b.txt",
"GRAPHRAG_EMBEDDING_BATCH_MAX_TOKENS": "17",
"GRAPHRAG_EMBEDDING_BATCH_SIZE": "1000000",
"GRAPHRAG_EMBEDDING_CONCURRENT_REQUESTS": "12",
"GRAPHRAG_EMBEDDING_DEPLOYMENT_NAME": "model-deployment-name",
"GRAPHRAG_EMBEDDING_MAX_RETRIES": "3",
"GRAPHRAG_EMBEDDING_MAX_RETRY_WAIT": "0.1123",
"GRAPHRAG_EMBEDDING_MODEL": "text-embedding-2",
"GRAPHRAG_EMBEDDING_REQUESTS_PER_MINUTE": "500",
"GRAPHRAG_EMBEDDING_SKIP": "a1,b1,c1",
"GRAPHRAG_EMBEDDING_SLEEP_ON_RATE_LIMIT_RECOMMENDATION": "False",
"GRAPHRAG_EMBEDDING_TARGET": "all",
"GRAPHRAG_EMBEDDING_THREAD_COUNT": "2345",
"GRAPHRAG_EMBEDDING_THREAD_STAGGER": "0.456",
"GRAPHRAG_EMBEDDING_TOKENS_PER_MINUTE": "7000",
"GRAPHRAG_EMBEDDING_TYPE": "azure_openai_embedding",
"GRAPHRAG_ENCODING_MODEL": "test123",
"GRAPHRAG_INPUT_STORAGE_ACCOUNT_BLOB_URL": "input_account_blob_url",
"GRAPHRAG_ENTITY_EXTRACTION_ENTITY_TYPES": "cat,dog,elephant",
"GRAPHRAG_ENTITY_EXTRACTION_MAX_GLEANINGS": "112",
"GRAPHRAG_ENTITY_EXTRACTION_PROMPT_FILE": "tests/unit/config/prompt-c.txt",
"GRAPHRAG_ENTITY_EXTRACTION_ENCODING_MODEL": "encoding_b",
"GRAPHRAG_INPUT_BASE_DIR": "/some/input/dir",
"GRAPHRAG_INPUT_CONNECTION_STRING": "input_cs",
"GRAPHRAG_INPUT_CONTAINER_NAME": "input_cn",
"GRAPHRAG_INPUT_DOCUMENT_ATTRIBUTE_COLUMNS": "test1,test2",
"GRAPHRAG_INPUT_ENCODING": "utf-16",
"GRAPHRAG_INPUT_FILE_PATTERN": ".*\\test\\.txt$",
"GRAPHRAG_INPUT_SOURCE_COLUMN": "test_source",
"GRAPHRAG_INPUT_TYPE": "blob",
"GRAPHRAG_INPUT_TEXT_COLUMN": "test_text",
"GRAPHRAG_INPUT_TIMESTAMP_COLUMN": "test_timestamp",
"GRAPHRAG_INPUT_TIMESTAMP_FORMAT": "test_format",
"GRAPHRAG_INPUT_TITLE_COLUMN": "test_title",
"GRAPHRAG_INPUT_FILE_TYPE": "text",
"GRAPHRAG_LLM_CONCURRENT_REQUESTS": "12",
"GRAPHRAG_LLM_DEPLOYMENT_NAME": "model-deployment-name-x",
"GRAPHRAG_LLM_MAX_RETRIES": "312",
"GRAPHRAG_LLM_MAX_RETRY_WAIT": "0.1122",
"GRAPHRAG_LLM_MAX_TOKENS": "15000",
"GRAPHRAG_LLM_MODEL_SUPPORTS_JSON": "true",
"GRAPHRAG_LLM_MODEL": "test-llm",
"GRAPHRAG_LLM_N": "1",
"GRAPHRAG_LLM_REQUEST_TIMEOUT": "12.7",
"GRAPHRAG_LLM_REQUESTS_PER_MINUTE": "900",
"GRAPHRAG_LLM_SLEEP_ON_RATE_LIMIT_RECOMMENDATION": "False",
"GRAPHRAG_LLM_THREAD_COUNT": "987",
"GRAPHRAG_LLM_THREAD_STAGGER": "0.123",
"GRAPHRAG_LLM_TOKENS_PER_MINUTE": "8000",
"GRAPHRAG_LLM_TYPE": "azure_openai_chat",
"GRAPHRAG_MAX_CLUSTER_SIZE": "123",
"GRAPHRAG_NODE2VEC_ENABLED": "true",
"GRAPHRAG_NODE2VEC_ITERATIONS": "878787",
"GRAPHRAG_NODE2VEC_NUM_WALKS": "5000000",
"GRAPHRAG_NODE2VEC_RANDOM_SEED": "010101",
"GRAPHRAG_NODE2VEC_WALK_LENGTH": "555111",
"GRAPHRAG_NODE2VEC_WINDOW_SIZE": "12345",
"GRAPHRAG_REPORTING_STORAGE_ACCOUNT_BLOB_URL": "reporting_account_blob_url",
"GRAPHRAG_REPORTING_BASE_DIR": "/some/reporting/dir",
"GRAPHRAG_REPORTING_CONNECTION_STRING": "test_cs2",
"GRAPHRAG_REPORTING_CONTAINER_NAME": "test_cn2",
"GRAPHRAG_REPORTING_TYPE": "blob",
"GRAPHRAG_SKIP_WORKFLOWS": "a,b,c",
"GRAPHRAG_SNAPSHOT_GRAPHML": "true",
"GRAPHRAG_SNAPSHOT_RAW_ENTITIES": "true",
"GRAPHRAG_SNAPSHOT_TOP_LEVEL_NODES": "true",
"GRAPHRAG_SNAPSHOT_EMBEDDINGS": "true",
"GRAPHRAG_SNAPSHOT_TRANSIENT": "true",
"GRAPHRAG_STORAGE_STORAGE_ACCOUNT_BLOB_URL": "storage_account_blob_url",
"GRAPHRAG_STORAGE_BASE_DIR": "/some/storage/dir",
"GRAPHRAG_STORAGE_CONNECTION_STRING": "test_cs",
"GRAPHRAG_STORAGE_CONTAINER_NAME": "test_cn",
"GRAPHRAG_STORAGE_TYPE": "blob",
"GRAPHRAG_SUMMARIZE_DESCRIPTIONS_MAX_LENGTH": "12345",
"GRAPHRAG_SUMMARIZE_DESCRIPTIONS_PROMPT_FILE": "tests/unit/config/prompt-d.txt",
"GRAPHRAG_LLM_TEMPERATURE": "0.0",
"GRAPHRAG_LLM_TOP_P": "1.0",
"GRAPHRAG_UMAP_ENABLED": "true",
"GRAPHRAG_LOCAL_SEARCH_TEXT_UNIT_PROP": "0.713",
"GRAPHRAG_LOCAL_SEARCH_COMMUNITY_PROP": "0.1234",
"GRAPHRAG_LOCAL_SEARCH_LLM_TEMPERATURE": "0.1",
"GRAPHRAG_LOCAL_SEARCH_LLM_TOP_P": "0.9",
"GRAPHRAG_LOCAL_SEARCH_LLM_N": "2",
"GRAPHRAG_LOCAL_SEARCH_LLM_MAX_TOKENS": "12",
"GRAPHRAG_LOCAL_SEARCH_TOP_K_RELATIONSHIPS": "15",
"GRAPHRAG_LOCAL_SEARCH_TOP_K_ENTITIES": "14",
"GRAPHRAG_LOCAL_SEARCH_CONVERSATION_HISTORY_MAX_TURNS": "2",
"GRAPHRAG_LOCAL_SEARCH_MAX_TOKENS": "142435",
"GRAPHRAG_GLOBAL_SEARCH_LLM_TEMPERATURE": "0.1",
"GRAPHRAG_GLOBAL_SEARCH_LLM_TOP_P": "0.9",
"GRAPHRAG_GLOBAL_SEARCH_LLM_N": "2",
"GRAPHRAG_GLOBAL_SEARCH_MAX_TOKENS": "5123",
"GRAPHRAG_GLOBAL_SEARCH_DATA_MAX_TOKENS": "123",
"GRAPHRAG_GLOBAL_SEARCH_MAP_MAX_TOKENS": "4123",
"GRAPHRAG_GLOBAL_SEARCH_CONCURRENCY": "7",
"GRAPHRAG_GLOBAL_SEARCH_REDUCE_MAX_TOKENS": "15432",
}
class TestDefaultConfig(unittest.TestCase):
def test_clear_warnings(self):
"""Just clearing unused import warnings"""
assert CacheConfig is not None
assert ChunkingConfig is not None
assert ClaimExtractionConfig is not None
assert ClusterGraphConfig is not None
assert CommunityReportsConfig is not None
assert DRIFTSearchConfig is not None
assert EmbedGraphConfig is not None
assert EntityExtractionConfig is not None
assert GlobalSearchConfig is not None
assert GraphRagConfig is not None
assert InputConfig is not None
assert LLMParameters is not None
assert LocalSearchConfig is not None
assert BasicSearchConfig is not None
assert ParallelizationParameters is not None
assert ReportingConfig is not None
assert SnapshotsConfig is not None
assert StorageConfig is not None
assert SummarizeDescriptionsConfig is not None
assert TextEmbeddingConfig is not None
assert UmapConfig is not None
assert PipelineFileReportingConfig is not None
assert PipelineFileStorageConfig is not None
assert PipelineInputConfig is not None
assert PipelineFileCacheConfig is not None
@mock.patch.dict(os.environ, {"OPENAI_API_KEY": "test"}, clear=True)
def test_string_repr(self):
# __str__ can be json loaded
config = create_graphrag_config()
string_repr = str(config)
assert string_repr is not None
assert json.loads(string_repr) is not None
# __repr__ can be eval()'d
repr_str = config.__repr__()
# TODO: add __repr__ to enum
repr_str = repr_str.replace("async_mode=<AsyncType.Threaded: 'threaded'>,", "")
assert eval(repr_str) is not None
@mock.patch.dict(os.environ, {}, clear=True)
def test_default_config_with_no_env_vars_throws(self):
with pytest.raises(ApiKeyMissingError):
# This should throw an error because the API key is missing
create_graphrag_config()
@mock.patch.dict(os.environ, {"GRAPHRAG_API_KEY": "test"}, clear=True)
def test_default_config_with_api_key_passes(self):
# doesn't throw
config = create_graphrag_config()
assert config is not None
@mock.patch.dict(os.environ, {"OPENAI_API_KEY": "test"}, clear=True)
def test_default_config_with_oai_key_passes_envvar(self):
# doesn't throw
config = create_graphrag_config()
assert config is not None
def test_default_config_with_oai_key_passes_obj(self):
# doesn't throw
config = create_graphrag_config({"llm": {"api_key": "test"}})
assert config is not None
@mock.patch.dict(
os.environ,
{"GRAPHRAG_API_KEY": "test", "GRAPHRAG_LLM_TYPE": "azure_openai_chat"},
clear=True,
)
def test_throws_if_azure_is_used_without_api_base_envvar(self):
with pytest.raises(AzureApiBaseMissingError):
create_graphrag_config()
@mock.patch.dict(
os.environ,
{
"GRAPHRAG_API_KEY": "test",
"GRAPHRAG_LLM_TYPE": "azure_openai_chat",
"GRAPHRAG_API_BASE": "http://some/base",
},
clear=True,
)
def test_throws_if_azure_is_used_without_llm_deployment_name_envvar(self):
with pytest.raises(AzureDeploymentNameMissingError):
create_graphrag_config()
@mock.patch.dict(
os.environ,
{
"GRAPHRAG_API_KEY": "test",
"GRAPHRAG_EMBEDDING_TYPE": "azure_openai_embedding",
"GRAPHRAG_EMBEDDING_DEPLOYMENT_NAME": "x",
},
clear=True,
)
def test_throws_if_azure_is_used_without_embedding_api_base_envvar(self):
with pytest.raises(AzureApiBaseMissingError):
create_graphrag_config()
@mock.patch.dict(
os.environ,
{
"GRAPHRAG_API_KEY": "test",
"GRAPHRAG_API_BASE": "http://some/base",
"GRAPHRAG_LLM_DEPLOYMENT_NAME": "x",
"GRAPHRAG_LLM_TYPE": "azure_openai_chat",
"GRAPHRAG_EMBEDDING_TYPE": "azure_openai_embedding",
},
clear=True,
)
def test_throws_if_azure_is_used_without_embedding_deployment_name_envvar(self):
with pytest.raises(AzureDeploymentNameMissingError):
create_graphrag_config()
@mock.patch.dict(
os.environ,
{
"GRAPHRAG_API_KEY": "test",
"GRAPHRAG_LLM_TYPE": "azure_openai_chat",
"GRAPHRAG_LLM_DEPLOYMENT_NAME": "x",
},
clear=True,
)
def test_throws_if_azure_is_used_without_api_base(self):
with pytest.raises(AzureApiBaseMissingError):
create_graphrag_config()
@mock.patch.dict(
os.environ,
{
"GRAPHRAG_API_KEY": "test",
"GRAPHRAG_LLM_TYPE": "azure_openai_chat",
"GRAPHRAG_LLM_API_BASE": "http://some/base",
},
clear=True,
)
def test_throws_if_azure_is_used_without_llm_deployment_name(self):
with pytest.raises(AzureDeploymentNameMissingError):
create_graphrag_config()
@mock.patch.dict(
os.environ,
{
"GRAPHRAG_API_KEY": "test",
"GRAPHRAG_LLM_TYPE": "azure_openai_chat",
"GRAPHRAG_API_BASE": "http://some/base",
"GRAPHRAG_LLM_DEPLOYMENT_NAME": "model-deployment-name-x",
"GRAPHRAG_EMBEDDING_TYPE": "azure_openai_embedding",
},
clear=True,
)
def test_throws_if_azure_is_used_without_embedding_deployment_name(self):
with pytest.raises(AzureDeploymentNameMissingError):
create_graphrag_config()
@mock.patch.dict(
os.environ,
{
"GRAPHRAG_LLM_API_KEY": "test",
"GRAPHRAG_ENTITY_EXTRACTION_MAX_GLEANINGS": "0",
"GRAPHRAG_CLAIM_EXTRACTION_MAX_GLEANINGS": "0",
},
clear=True,
)
def test_can_set_gleanings_to_zero(self):
parameters = create_graphrag_config()
assert parameters.claim_extraction.max_gleanings == 0
assert parameters.entity_extraction.max_gleanings == 0
@mock.patch.dict(
os.environ,
{"GRAPHRAG_LLM_API_KEY": "test", "GRAPHRAG_CHUNK_BY_COLUMNS": ""},
clear=True,
)
def test_can_set_no_chunk_by_columns(self):
parameters = create_graphrag_config()
assert parameters.chunks.group_by_columns == []
def test_all_env_vars_is_accurate(self):
env_var_docs_path = Path("docs/config/env_vars.md")
env_var_docs = env_var_docs_path.read_text(encoding="utf-8")
def find_envvar_names(text) -> set[str]:
pattern = r"`(GRAPHRAG_[^`]+)`"
found = re.findall(pattern, text)
found = {f for f in found if not f.endswith("_")}
return {*found}
graphrag_strings = find_envvar_names(env_var_docs)
missing = {s for s in graphrag_strings if s not in ALL_ENV_VARS} - {
# Remove configs covered by the base LLM connection configs
"GRAPHRAG_LLM_API_KEY",
"GRAPHRAG_LLM_API_BASE",
"GRAPHRAG_LLM_API_VERSION",
"GRAPHRAG_LLM_API_ORGANIZATION",
"GRAPHRAG_LLM_API_PROXY",
"GRAPHRAG_EMBEDDING_API_KEY",
"GRAPHRAG_EMBEDDING_API_BASE",
"GRAPHRAG_EMBEDDING_API_VERSION",
"GRAPHRAG_EMBEDDING_API_ORGANIZATION",
"GRAPHRAG_EMBEDDING_API_PROXY",
}
if missing:
msg = f"{len(missing)} missing env vars: {missing}"
print(msg)
raise ValueError(msg)
@mock.patch.dict(
os.environ,
{"GRAPHRAG_API_KEY": "test"},
clear=True,
)
@mock.patch.dict(
os.environ,
ALL_ENV_VARS,
clear=True,
)
@mock.patch.dict(
os.environ,
{"GRAPHRAG_API_KEY": "test"},
clear=True,
)
def test_default_values(self) -> None:
parameters = create_graphrag_config()
assert parameters.async_mode == defs.ASYNC_MODE
assert parameters.cache.base_dir == defs.CACHE_BASE_DIR
assert parameters.cache.type == defs.CACHE_TYPE
assert parameters.cache.base_dir == defs.CACHE_BASE_DIR
assert parameters.chunks.group_by_columns == defs.CHUNK_GROUP_BY_COLUMNS
assert parameters.chunks.overlap == defs.CHUNK_OVERLAP
assert parameters.chunks.size == defs.CHUNK_SIZE
assert parameters.claim_extraction.description == defs.CLAIM_DESCRIPTION
assert parameters.claim_extraction.max_gleanings == defs.CLAIM_MAX_GLEANINGS
assert (
parameters.community_reports.max_input_length
== defs.COMMUNITY_REPORT_MAX_INPUT_LENGTH
)
assert (
parameters.community_reports.max_length == defs.COMMUNITY_REPORT_MAX_LENGTH
)
assert parameters.embeddings.batch_max_tokens == defs.EMBEDDING_BATCH_MAX_TOKENS
assert parameters.embeddings.batch_size == defs.EMBEDDING_BATCH_SIZE
assert parameters.embeddings.llm.model == defs.EMBEDDING_MODEL
assert parameters.embeddings.target == defs.EMBEDDING_TARGET
assert parameters.embeddings.llm.type == defs.EMBEDDING_TYPE
assert (
parameters.embeddings.llm.requests_per_minute
== defs.LLM_REQUESTS_PER_MINUTE
)
assert parameters.embeddings.llm.tokens_per_minute == defs.LLM_TOKENS_PER_MINUTE
assert (
parameters.embeddings.llm.sleep_on_rate_limit_recommendation
== defs.LLM_SLEEP_ON_RATE_LIMIT_RECOMMENDATION
)
assert (
parameters.entity_extraction.entity_types
== defs.ENTITY_EXTRACTION_ENTITY_TYPES
)
assert (
parameters.entity_extraction.max_gleanings
== defs.ENTITY_EXTRACTION_MAX_GLEANINGS
)
assert parameters.encoding_model == defs.ENCODING_MODEL
assert parameters.input.base_dir == defs.INPUT_BASE_DIR
assert parameters.input.file_pattern == defs.INPUT_CSV_PATTERN
assert parameters.input.encoding == defs.INPUT_FILE_ENCODING
assert parameters.input.type == defs.INPUT_TYPE
assert parameters.input.base_dir == defs.INPUT_BASE_DIR
assert parameters.input.text_column == defs.INPUT_TEXT_COLUMN
assert parameters.input.file_type == defs.INPUT_FILE_TYPE
assert parameters.llm.concurrent_requests == defs.LLM_CONCURRENT_REQUESTS
assert parameters.llm.max_retries == defs.LLM_MAX_RETRIES
assert parameters.llm.max_retry_wait == defs.LLM_MAX_RETRY_WAIT
assert parameters.llm.max_tokens == defs.LLM_MAX_TOKENS
assert parameters.llm.model == defs.LLM_MODEL
assert parameters.llm.request_timeout == defs.LLM_REQUEST_TIMEOUT
assert parameters.llm.requests_per_minute == defs.LLM_REQUESTS_PER_MINUTE
assert parameters.llm.tokens_per_minute == defs.LLM_TOKENS_PER_MINUTE
assert (
parameters.llm.sleep_on_rate_limit_recommendation
== defs.LLM_SLEEP_ON_RATE_LIMIT_RECOMMENDATION
)
assert parameters.llm.type == defs.LLM_TYPE
assert parameters.cluster_graph.max_cluster_size == defs.MAX_CLUSTER_SIZE
assert parameters.embed_graph.enabled == defs.NODE2VEC_ENABLED
assert parameters.embed_graph.iterations == defs.NODE2VEC_ITERATIONS
assert parameters.embed_graph.num_walks == defs.NODE2VEC_NUM_WALKS
assert parameters.embed_graph.random_seed == defs.NODE2VEC_RANDOM_SEED
assert parameters.embed_graph.walk_length == defs.NODE2VEC_WALK_LENGTH
assert parameters.embed_graph.window_size == defs.NODE2VEC_WINDOW_SIZE
assert (
parameters.parallelization.num_threads == defs.PARALLELIZATION_NUM_THREADS
)
assert parameters.parallelization.stagger == defs.PARALLELIZATION_STAGGER
assert parameters.reporting.type == defs.REPORTING_TYPE
assert parameters.reporting.base_dir == defs.REPORTING_BASE_DIR
assert parameters.snapshots.graphml == defs.SNAPSHOTS_GRAPHML
assert parameters.snapshots.embeddings == defs.SNAPSHOTS_EMBEDDINGS
assert parameters.snapshots.transient == defs.SNAPSHOTS_TRANSIENT
assert parameters.storage.base_dir == defs.STORAGE_BASE_DIR
assert parameters.storage.type == defs.STORAGE_TYPE
assert parameters.umap.enabled == defs.UMAP_ENABLED
@mock.patch.dict(
os.environ,
{"GRAPHRAG_API_KEY": "test"},
clear=True,
)
def test_prompt_file_reading(self):
config = create_graphrag_config({
"entity_extraction": {"prompt": "tests/unit/config/prompt-a.txt"},
"claim_extraction": {"prompt": "tests/unit/config/prompt-b.txt"},
"community_reports": {"prompt": "tests/unit/config/prompt-c.txt"},
"summarize_descriptions": {"prompt": "tests/unit/config/prompt-d.txt"},
})
strategy = config.entity_extraction.resolved_strategy(".", "abc123")
assert strategy["extraction_prompt"] == "Hello, World! A"
assert strategy["encoding_name"] == "abc123"
strategy = config.claim_extraction.resolved_strategy(".", "encoding_b")
assert strategy["extraction_prompt"] == "Hello, World! B"
strategy = config.community_reports.resolved_strategy(".")
assert strategy["extraction_prompt"] == "Hello, World! C"
strategy = config.summarize_descriptions.resolved_strategy(".")
assert strategy["summarize_prompt"] == "Hello, World! D"
@mock.patch.dict(
os.environ,
{
"PIPELINE_LLM_API_KEY": "test",
"PIPELINE_LLM_API_BASE": "http://test",
"PIPELINE_LLM_API_VERSION": "v1",
"PIPELINE_LLM_MODEL": "test-llm",
"PIPELINE_LLM_DEPLOYMENT_NAME": "test",
},
clear=True,
)
def test_yaml_load_e2e():
config_dict = yaml.safe_load(
"""
input:
file_type: text
llm:
type: azure_openai_chat
api_key: ${PIPELINE_LLM_API_KEY}
api_base: ${PIPELINE_LLM_API_BASE}
api_version: ${PIPELINE_LLM_API_VERSION}
model: ${PIPELINE_LLM_MODEL}
deployment_name: ${PIPELINE_LLM_DEPLOYMENT_NAME}
model_supports_json: True
tokens_per_minute: 80000
requests_per_minute: 900
thread_count: 50
concurrent_requests: 25
"""
)
# create default configuration pipeline parameters from the custom settings
model = config_dict
parameters = create_graphrag_config(model, ".")
assert parameters.llm.api_key == "test"
assert parameters.llm.model == "test-llm"
assert parameters.llm.api_base == "http://test"
assert parameters.llm.api_version == "v1"
assert parameters.llm.deployment_name == "test"
# generate the pipeline from the default parameters
pipeline_config = create_pipeline_config(parameters, True)
config_str = pipeline_config.model_dump_json()
assert "${PIPELINE_LLM_API_KEY}" not in config_str
assert "${PIPELINE_LLM_API_BASE}" not in config_str
assert "${PIPELINE_LLM_API_VERSION}" not in config_str
assert "${PIPELINE_LLM_MODEL}" not in config_str
assert "${PIPELINE_LLM_DEPLOYMENT_NAME}" not in config_str

View File

@ -1,42 +0,0 @@
# Copyright (c) 2024 Microsoft Corporation.
# Licensed under the MIT License
from pathlib import Path
from graphrag.config.resolve_path import resolve_path
def test_resolve_path_no_timestamp_with_run_id():
path = Path("path/to/data")
result = resolve_path(path, pattern_or_timestamp_value="20240812-121000")
assert result == path
def test_resolve_path_no_timestamp_without_run_id():
path = Path("path/to/data")
result = resolve_path(path)
assert result == path
def test_resolve_path_with_timestamp_and_run_id():
path = Path("some/path/${timestamp}/data")
expected = Path("some/path/20240812/data")
result = resolve_path(path, pattern_or_timestamp_value="20240812")
assert result == expected
def test_resolve_path_with_timestamp_and_inferred_directory():
cwd = Path(__file__).parent
path = cwd / "fixtures/timestamp_dirs/${timestamp}/data"
expected = cwd / "fixtures/timestamp_dirs/20240812-120000/data"
result = resolve_path(path)
assert result == expected
def test_resolve_path_absolute():
cwd = Path(__file__).parent
path = "fixtures/timestamp_dirs/${timestamp}/data"
expected = cwd / "fixtures/timestamp_dirs/20240812-120000/data"
result = resolve_path(path, cwd)
assert result == expected
assert result.is_absolute()

tests/unit/config/utils.py Normal file
View File

@ -0,0 +1,569 @@
# Copyright (c) 2024 Microsoft Corporation.
# Licensed under the MIT License
from pydantic import BaseModel
import graphrag.config.defaults as defs
from graphrag.config.models.basic_search_config import BasicSearchConfig
from graphrag.config.models.cache_config import CacheConfig
from graphrag.config.models.chunking_config import ChunkingConfig
from graphrag.config.models.claim_extraction_config import ClaimExtractionConfig
from graphrag.config.models.cluster_graph_config import ClusterGraphConfig
from graphrag.config.models.community_reports_config import CommunityReportsConfig
from graphrag.config.models.drift_search_config import DRIFTSearchConfig
from graphrag.config.models.embed_graph_config import EmbedGraphConfig
from graphrag.config.models.entity_extraction_config import EntityExtractionConfig
from graphrag.config.models.global_search_config import GlobalSearchConfig
from graphrag.config.models.graph_rag_config import GraphRagConfig
from graphrag.config.models.input_config import InputConfig
from graphrag.config.models.language_model_config import LanguageModelConfig
from graphrag.config.models.local_search_config import LocalSearchConfig
from graphrag.config.models.output_config import OutputConfig
from graphrag.config.models.reporting_config import ReportingConfig
from graphrag.config.models.snapshots_config import SnapshotsConfig
from graphrag.config.models.summarize_descriptions_config import (
SummarizeDescriptionsConfig,
)
from graphrag.config.models.text_embedding_config import TextEmbeddingConfig
from graphrag.config.models.umap_config import UmapConfig
from graphrag.config.models.vector_store_config import VectorStoreConfig
FAKE_API_KEY = "NOT_AN_API_KEY"
DEFAULT_CHAT_MODEL_CONFIG = {
"api_key": FAKE_API_KEY,
"type": defs.LLM_TYPE.value,
"model": defs.LLM_MODEL,
}
DEFAULT_EMBEDDING_MODEL_CONFIG = {
"api_key": FAKE_API_KEY,
"type": defs.EMBEDDING_TYPE.value,
"model": defs.EMBEDDING_MODEL,
}
DEFAULT_MODEL_CONFIG = {
defs.DEFAULT_CHAT_MODEL_ID: DEFAULT_CHAT_MODEL_CONFIG,
defs.DEFAULT_EMBEDDING_MODEL_ID: DEFAULT_EMBEDDING_MODEL_CONFIG,
}
DEFAULT_GRAPHRAG_CONFIG_SETTINGS = {
"models": DEFAULT_MODEL_CONFIG,
"vector_store": {
"type": defs.VECTOR_STORE_TYPE,
"db_uri": defs.VECTOR_STORE_DB_URI,
"container_name": defs.VECTOR_STORE_CONTAINER_NAME,
"overwrite": defs.VECTOR_STORE_OVERWRITE,
"url": None,
"api_key": None,
"audience": None,
},
"reporting": {
"type": defs.REPORTING_TYPE,
"base_dir": defs.REPORTING_BASE_DIR,
"connection_string": None,
"container_name": None,
"storage_account_blob_url": None,
},
"output": {
"type": defs.OUTPUT_TYPE,
"base_dir": defs.OUTPUT_BASE_DIR,
"connection_string": None,
"container_name": None,
"storage_account_blob_url": None,
},
"update_index_output": None,
"cache": {
"type": defs.CACHE_TYPE,
"base_dir": defs.CACHE_BASE_DIR,
"connection_string": None,
"container_name": None,
"storage_account_blob_url": None,
"cosmosdb_account_url": None,
},
"input": {
"type": defs.INPUT_TYPE,
"file_type": defs.INPUT_FILE_TYPE,
"base_dir": defs.INPUT_BASE_DIR,
"connection_string": None,
"storage_account_blob_url": None,
"container_name": None,
"encoding": defs.INPUT_FILE_ENCODING,
"file_pattern": defs.INPUT_TEXT_PATTERN,
"file_filter": None,
"source_column": None,
"timestamp_column": None,
"timestamp_format": None,
"text_column": defs.INPUT_TEXT_COLUMN,
"title_column": None,
"document_attribute_columns": [],
},
"embed_graph": {
"enabled": defs.NODE2VEC_ENABLED,
"dimensions": defs.NODE2VEC_DIMENSIONS,
"num_walks": defs.NODE2VEC_NUM_WALKS,
"walk_length": defs.NODE2VEC_WALK_LENGTH,
"window_size": defs.NODE2VEC_WINDOW_SIZE,
"iterations": defs.NODE2VEC_ITERATIONS,
"random_seed": defs.NODE2VEC_RANDOM_SEED,
"use_lcc": defs.USE_LCC,
},
"embeddings": {
"batch_size": defs.EMBEDDING_BATCH_SIZE,
"batch_max_tokens": defs.EMBEDDING_BATCH_MAX_TOKENS,
"target": defs.EMBEDDING_TARGET,
"strategy": None,
"model_id": defs.EMBEDDING_MODEL_ID,
},
"chunks": {
"size": defs.CHUNK_SIZE,
"overlap": defs.CHUNK_OVERLAP,
"group_by_columns": defs.CHUNK_GROUP_BY_COLUMNS,
"strategy": defs.CHUNK_STRATEGY,
"encoding_model": defs.ENCODING_MODEL,
},
"snapshots": {
"embeddings": defs.SNAPSHOTS_EMBEDDINGS,
"graphml": defs.SNAPSHOTS_GRAPHML,
"transient": defs.SNAPSHOTS_TRANSIENT,
},
"entity_extraction": {
"prompt": None,
"entity_types": defs.ENTITY_EXTRACTION_ENTITY_TYPES,
"max_gleanings": defs.ENTITY_EXTRACTION_MAX_GLEANINGS,
"strategy": None,
"encoding_model": None,
"model_id": defs.ENTITY_EXTRACTION_MODEL_ID,
},
"summarize_descriptions": {
"prompt": None,
"max_length": defs.SUMMARIZE_DESCRIPTIONS_MAX_LENGTH,
"strategy": None,
"model_id": defs.SUMMARIZE_MODEL_ID,
},
"community_report": {
"prompt": None,
"max_length": defs.COMMUNITY_REPORT_MAX_LENGTH,
"max_input_length": defs.COMMUNITY_REPORT_MAX_INPUT_LENGTH,
"strategy": None,
"model_id": defs.COMMUNITY_REPORT_MODEL_ID,
},
"claim_extaction": {
"enabled": defs.CLAIM_EXTRACTION_ENABLED,
"prompt": None,
"description": defs.CLAIM_DESCRIPTION,
"max_gleanings": defs.CLAIM_MAX_GLEANINGS,
"strategy": None,
"encoding_model": None,
"model_id": defs.CLAIM_EXTRACTION_MODEL_ID,
},
"cluster_graph": {
"max_cluster_size": defs.MAX_CLUSTER_SIZE,
"use_lcc": defs.USE_LCC,
"seed": defs.CLUSTER_GRAPH_SEED,
},
"umap": {"enabled": defs.UMAP_ENABLED},
"local_search": {
"prompt": None,
"text_unit_prop": defs.LOCAL_SEARCH_TEXT_UNIT_PROP,
"community_prop": defs.LOCAL_SEARCH_COMMUNITY_PROP,
"conversation_history_max_turns": defs.LOCAL_SEARCH_CONVERSATION_HISTORY_MAX_TURNS,
"top_k_entities": defs.LOCAL_SEARCH_TOP_K_MAPPED_ENTITIES,
"top_k_relationships": defs.LOCAL_SEARCH_TOP_K_RELATIONSHIPS,
"temperature": defs.LOCAL_SEARCH_LLM_TEMPERATURE,
"top_p": defs.LOCAL_SEARCH_LLM_TOP_P,
"n": defs.LOCAL_SEARCH_LLM_N,
"max_tokens": defs.LOCAL_SEARCH_MAX_TOKENS,
"llm_max_tokens": defs.LOCAL_SEARCH_LLM_MAX_TOKENS,
},
"global_search": {
"map_prompt": None,
"reduce_prompt": None,
"knowledge_prompt": None,
"temperature": defs.GLOBAL_SEARCH_LLM_TEMPERATURE,
"top_p": defs.GLOBAL_SEARCH_LLM_TOP_P,
"n": defs.GLOBAL_SEARCH_LLM_N,
"max_tokens": defs.GLOBAL_SEARCH_MAX_TOKENS,
"data_max_tokens": defs.GLOBAL_SEARCH_DATA_MAX_TOKENS,
"map_max_tokens": defs.GLOBAL_SEARCH_MAP_MAX_TOKENS,
"reduce_max_tokens": defs.GLOBAL_SEARCH_REDUCE_MAX_TOKENS,
"concurrency": defs.GLOBAL_SEARCH_CONCURRENCY,
"dynamic_search_llm": defs.DYNAMIC_SEARCH_LLM_MODEL,
"dynamic_search_threshold": defs.DYNAMIC_SEARCH_RATE_THRESHOLD,
"dynamic_search_keep_parent": defs.DYNAMIC_SEARCH_KEEP_PARENT,
"dynamic_search_num_repeats": defs.DYNAMIC_SEARCH_NUM_REPEATS,
"dynamic_search_use_summary": defs.DYNAMIC_SEARCH_USE_SUMMARY,
"dynamic_search_concurrent_coroutines": defs.DYNAMIC_SEARCH_CONCURRENT_COROUTINES,
"dynamic_search_max_level": defs.DYNAMIC_SEARCH_MAX_LEVEL,
},
"drift_search": {
"prompt": None,
"temperature": defs.DRIFT_SEARCH_LLM_TEMPERATURE,
"top_p": defs.DRIFT_SEARCH_LLM_TOP_P,
"n": defs.DRIFT_SEARCH_LLM_N,
"max_tokens": defs.DRIFT_SEARCH_MAX_TOKENS,
"data_max_tokens": defs.DRIFT_SEARCH_DATA_MAX_TOKENS,
"concurrency": defs.DRIFT_SEARCH_CONCURRENCY,
"drift_k_followups": defs.DRIFT_SEARCH_K_FOLLOW_UPS,
"primer_folds": defs.DRIFT_SEARCH_PRIMER_FOLDS,
"primer_llm_max_tokens": defs.DRIFT_SEARCH_PRIMER_MAX_TOKENS,
"n_depth": defs.DRIFT_N_DEPTH,
"local_search_text_unit_prop": defs.DRIFT_LOCAL_SEARCH_TEXT_UNIT_PROP,
"local_search_community_prop": defs.DRIFT_LOCAL_SEARCH_COMMUNITY_PROP,
"local_search_top_k_mapped_entities": defs.DRIFT_LOCAL_SEARCH_TOP_K_MAPPED_ENTITIES,
"local_search_top_k_relationships": defs.DRIFT_LOCAL_SEARCH_TOP_K_RELATIONSHIPS,
"local_search_max_data_tokens": defs.DRIFT_LOCAL_SEARCH_MAX_TOKENS,
"local_search_temperature": defs.DRIFT_LOCAL_SEARCH_LLM_TEMPERATURE,
"local_search_top_p": defs.DRIFT_LOCAL_SEARCH_LLM_TOP_P,
"local_search_n": defs.DRIFT_LOCAL_SEARCH_LLM_N,
"local_search_max_tokens": defs.DRIFT_LOCAL_SEARCH_MAX_TOKENS,
},
"basic_search": {
"prompt": None,
"text_unit_prop": defs.BASIC_SEARCH_TEXT_UNIT_PROP,
"conversation_history_max_turns": defs.BASIC_SEARCH_CONVERSATION_HISTORY_MAX_TURNS,
"temperature": defs.BASIC_SEARCH_LLM_TEMPERATURE,
"top_p": defs.BASIC_SEARCH_LLM_TOP_P,
"n": defs.BASIC_SEARCH_LLM_N,
"max_tokens": defs.BASIC_SEARCH_MAX_TOKENS,
"llm_max_tokens": defs.BASIC_SEARCH_LLM_MAX_TOKENS,
},
}
def get_default_graphrag_config(root_dir: str | None = None) -> GraphRagConfig:
    # copy the defaults so repeated calls do not leak a previously set root_dir
    settings = dict(DEFAULT_GRAPHRAG_CONFIG_SETTINGS)
    if root_dir is not None and root_dir.strip() != "":
        settings["root_dir"] = root_dir
    return GraphRagConfig(**settings)
def assert_language_model_configs(
actual: LanguageModelConfig, expected: LanguageModelConfig
) -> None:
assert actual.api_key == expected.api_key
assert actual.azure_auth_type == expected.azure_auth_type
assert actual.type == expected.type
assert actual.model == expected.model
assert actual.encoding_model == expected.encoding_model
assert actual.max_tokens == expected.max_tokens
assert actual.temperature == expected.temperature
assert actual.top_p == expected.top_p
assert actual.n == expected.n
assert actual.frequency_penalty == expected.frequency_penalty
assert actual.presence_penalty == expected.presence_penalty
assert actual.request_timeout == expected.request_timeout
assert actual.api_base == expected.api_base
assert actual.api_version == expected.api_version
assert actual.deployment_name == expected.deployment_name
assert actual.organization == expected.organization
assert actual.proxy == expected.proxy
assert actual.audience == expected.audience
assert actual.model_supports_json == expected.model_supports_json
assert actual.tokens_per_minute == expected.tokens_per_minute
assert actual.requests_per_minute == expected.requests_per_minute
assert actual.max_retries == expected.max_retries
assert actual.max_retry_wait == expected.max_retry_wait
assert (
actual.sleep_on_rate_limit_recommendation
== expected.sleep_on_rate_limit_recommendation
)
assert actual.concurrent_requests == expected.concurrent_requests
assert actual.parallelization_stagger == expected.parallelization_stagger
assert actual.parallelization_num_threads == expected.parallelization_num_threads
assert actual.async_mode == expected.async_mode
if actual.responses is not None:
assert expected.responses is not None
assert len(actual.responses) == len(expected.responses)
for a, e in zip(actual.responses, expected.responses, strict=True):
    assert isinstance(a, BaseModel)
    assert isinstance(e, BaseModel)
    assert a.model_dump() == e.model_dump()
else:
assert expected.responses is None
def assert_vector_store_configs(actual: VectorStoreConfig, expected: VectorStoreConfig):
assert actual.type == expected.type
assert actual.db_uri == expected.db_uri
assert actual.container_name == expected.container_name
assert actual.overwrite == expected.overwrite
assert actual.url == expected.url
assert actual.api_key == expected.api_key
assert actual.audience == expected.audience
def assert_reporting_configs(
actual: ReportingConfig, expected: ReportingConfig
) -> None:
assert actual.type == expected.type
assert actual.base_dir == expected.base_dir
assert actual.connection_string == expected.connection_string
assert actual.container_name == expected.container_name
assert actual.storage_account_blob_url == expected.storage_account_blob_url
def assert_output_configs(actual: OutputConfig, expected: OutputConfig) -> None:
assert expected.type == actual.type
assert expected.base_dir == actual.base_dir
assert expected.connection_string == actual.connection_string
assert expected.container_name == actual.container_name
assert expected.storage_account_blob_url == actual.storage_account_blob_url
assert expected.cosmosdb_account_url == actual.cosmosdb_account_url
def assert_cache_configs(actual: CacheConfig, expected: CacheConfig) -> None:
assert actual.type == expected.type
assert actual.base_dir == expected.base_dir
assert actual.connection_string == expected.connection_string
assert actual.container_name == expected.container_name
assert actual.storage_account_blob_url == expected.storage_account_blob_url
assert actual.cosmosdb_account_url == expected.cosmosdb_account_url
def assert_input_configs(actual: InputConfig, expected: InputConfig) -> None:
assert actual.type == expected.type
assert actual.file_type == expected.file_type
assert actual.base_dir == expected.base_dir
assert actual.connection_string == expected.connection_string
assert actual.storage_account_blob_url == expected.storage_account_blob_url
assert actual.container_name == expected.container_name
assert actual.encoding == expected.encoding
assert actual.file_pattern == expected.file_pattern
assert actual.file_filter == expected.file_filter
assert actual.source_column == expected.source_column
assert actual.timestamp_column == expected.timestamp_column
assert actual.timestamp_format == expected.timestamp_format
assert actual.text_column == expected.text_column
assert actual.title_column == expected.title_column
assert actual.document_attribute_columns == expected.document_attribute_columns
def assert_embed_graph_configs(
actual: EmbedGraphConfig, expected: EmbedGraphConfig
) -> None:
assert actual.enabled == expected.enabled
assert actual.dimensions == expected.dimensions
assert actual.num_walks == expected.num_walks
assert actual.walk_length == expected.walk_length
assert actual.window_size == expected.window_size
assert actual.iterations == expected.iterations
assert actual.random_seed == expected.random_seed
assert actual.use_lcc == expected.use_lcc
def assert_text_embedding_configs(
actual: TextEmbeddingConfig, expected: TextEmbeddingConfig
) -> None:
assert actual.batch_size == expected.batch_size
assert actual.batch_max_tokens == expected.batch_max_tokens
assert actual.target == expected.target
assert actual.names == expected.names
assert actual.strategy == expected.strategy
assert actual.model_id == expected.model_id
def assert_chunking_configs(actual: ChunkingConfig, expected: ChunkingConfig) -> None:
assert actual.size == expected.size
assert actual.overlap == expected.overlap
assert actual.group_by_columns == expected.group_by_columns
assert actual.strategy == expected.strategy
assert actual.encoding_model == expected.encoding_model
def assert_snapshots_configs(
actual: SnapshotsConfig, expected: SnapshotsConfig
) -> None:
assert actual.embeddings == expected.embeddings
assert actual.graphml == expected.graphml
assert actual.transient == expected.transient
def assert_entity_extraction_configs(
actual: EntityExtractionConfig, expected: EntityExtractionConfig
) -> None:
assert actual.prompt == expected.prompt
assert actual.entity_types == expected.entity_types
assert actual.max_gleanings == expected.max_gleanings
assert actual.strategy == expected.strategy
assert actual.encoding_model == expected.encoding_model
assert actual.model_id == expected.model_id
def assert_summarize_descriptions_configs(
actual: SummarizeDescriptionsConfig, expected: SummarizeDescriptionsConfig
) -> None:
assert actual.prompt == expected.prompt
assert actual.max_length == expected.max_length
assert actual.strategy == expected.strategy
assert actual.model_id == expected.model_id
def assert_community_reports_configs(
actual: CommunityReportsConfig, expected: CommunityReportsConfig
) -> None:
assert actual.prompt == expected.prompt
assert actual.max_length == expected.max_length
assert actual.max_input_length == expected.max_input_length
assert actual.strategy == expected.strategy
assert actual.model_id == expected.model_id
def assert_claim_extraction_configs(
actual: ClaimExtractionConfig, expected: ClaimExtractionConfig
) -> None:
assert actual.enabled == expected.enabled
assert actual.prompt == expected.prompt
assert actual.description == expected.description
assert actual.max_gleanings == expected.max_gleanings
assert actual.strategy == expected.strategy
assert actual.encoding_model == expected.encoding_model
assert actual.model_id == expected.model_id
def assert_cluster_graph_configs(
actual: ClusterGraphConfig, expected: ClusterGraphConfig
) -> None:
assert actual.max_cluster_size == expected.max_cluster_size
assert actual.use_lcc == expected.use_lcc
assert actual.seed == expected.seed
def assert_umap_configs(actual: UmapConfig, expected: UmapConfig) -> None:
assert actual.enabled == expected.enabled
def assert_local_search_configs(
actual: LocalSearchConfig, expected: LocalSearchConfig
) -> None:
assert actual.prompt == expected.prompt
assert actual.text_unit_prop == expected.text_unit_prop
assert actual.community_prop == expected.community_prop
assert (
actual.conversation_history_max_turns == expected.conversation_history_max_turns
)
assert actual.top_k_entities == expected.top_k_entities
assert actual.top_k_relationships == expected.top_k_relationships
assert actual.temperature == expected.temperature
assert actual.top_p == expected.top_p
assert actual.n == expected.n
assert actual.max_tokens == expected.max_tokens
assert actual.llm_max_tokens == expected.llm_max_tokens
def assert_global_search_configs(
actual: GlobalSearchConfig, expected: GlobalSearchConfig
) -> None:
assert actual.map_prompt == expected.map_prompt
assert actual.reduce_prompt == expected.reduce_prompt
assert actual.knowledge_prompt == expected.knowledge_prompt
assert actual.temperature == expected.temperature
assert actual.top_p == expected.top_p
assert actual.n == expected.n
assert actual.max_tokens == expected.max_tokens
assert actual.data_max_tokens == expected.data_max_tokens
assert actual.map_max_tokens == expected.map_max_tokens
assert actual.reduce_max_tokens == expected.reduce_max_tokens
assert actual.concurrency == expected.concurrency
assert actual.dynamic_search_llm == expected.dynamic_search_llm
assert actual.dynamic_search_threshold == expected.dynamic_search_threshold
assert actual.dynamic_search_keep_parent == expected.dynamic_search_keep_parent
assert actual.dynamic_search_num_repeats == expected.dynamic_search_num_repeats
assert actual.dynamic_search_use_summary == expected.dynamic_search_use_summary
assert (
actual.dynamic_search_concurrent_coroutines
== expected.dynamic_search_concurrent_coroutines
)
assert actual.dynamic_search_max_level == expected.dynamic_search_max_level
def assert_drift_search_configs(
actual: DRIFTSearchConfig, expected: DRIFTSearchConfig
) -> None:
assert actual.prompt == expected.prompt
assert actual.temperature == expected.temperature
assert actual.top_p == expected.top_p
assert actual.n == expected.n
assert actual.max_tokens == expected.max_tokens
assert actual.data_max_tokens == expected.data_max_tokens
assert actual.concurrency == expected.concurrency
assert actual.drift_k_followups == expected.drift_k_followups
assert actual.primer_folds == expected.primer_folds
assert actual.primer_llm_max_tokens == expected.primer_llm_max_tokens
assert actual.n_depth == expected.n_depth
assert actual.local_search_text_unit_prop == expected.local_search_text_unit_prop
assert actual.local_search_community_prop == expected.local_search_community_prop
assert (
actual.local_search_top_k_mapped_entities
== expected.local_search_top_k_mapped_entities
)
assert (
actual.local_search_top_k_relationships
== expected.local_search_top_k_relationships
)
assert actual.local_search_max_data_tokens == expected.local_search_max_data_tokens
assert actual.local_search_temperature == expected.local_search_temperature
assert actual.local_search_top_p == expected.local_search_top_p
assert actual.local_search_n == expected.local_search_n
assert (
actual.local_search_llm_max_gen_tokens
== expected.local_search_llm_max_gen_tokens
)
def assert_basic_search_configs(
actual: BasicSearchConfig, expected: BasicSearchConfig
) -> None:
assert actual.prompt == expected.prompt
assert actual.text_unit_prop == expected.text_unit_prop
assert (
actual.conversation_history_max_turns == expected.conversation_history_max_turns
)
assert actual.temperature == expected.temperature
assert actual.top_p == expected.top_p
assert actual.n == expected.n
assert actual.max_tokens == expected.max_tokens
assert actual.llm_max_tokens == expected.llm_max_tokens
def assert_graphrag_configs(actual: GraphRagConfig, expected: GraphRagConfig) -> None:
assert actual.root_dir == expected.root_dir
a_keys = sorted(actual.models.keys())
e_keys = sorted(expected.models.keys())
assert len(a_keys) == len(e_keys)
for a, e in zip(a_keys, e_keys, strict=False):
assert a == e
assert_language_model_configs(actual.models[a], expected.models[e])
assert_vector_store_configs(actual.vector_store, expected.vector_store)
assert_reporting_configs(actual.reporting, expected.reporting)
assert_output_configs(actual.output, expected.output)
if actual.update_index_output is not None:
assert expected.update_index_output is not None
assert_output_configs(actual.update_index_output, expected.update_index_output)
else:
assert expected.update_index_output is None
assert_cache_configs(actual.cache, expected.cache)
assert_input_configs(actual.input, expected.input)
assert_embed_graph_configs(actual.embed_graph, expected.embed_graph)
assert_text_embedding_configs(actual.embeddings, expected.embeddings)
assert_chunking_configs(actual.chunks, expected.chunks)
assert_snapshots_configs(actual.snapshots, expected.snapshots)
assert_entity_extraction_configs(
actual.entity_extraction, expected.entity_extraction
)
assert_summarize_descriptions_configs(
actual.summarize_descriptions, expected.summarize_descriptions
)
assert_community_reports_configs(
actual.community_reports, expected.community_reports
)
assert_claim_extraction_configs(actual.claim_extraction, expected.claim_extraction)
assert_cluster_graph_configs(actual.cluster_graph, expected.cluster_graph)
assert_umap_configs(actual.umap, expected.umap)
assert_local_search_configs(actual.local_search, expected.local_search)
assert_global_search_configs(actual.global_search, expected.global_search)
assert_drift_search_configs(actual.drift_search, expected.drift_search)
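For orientation, a minimal usage sketch of the helpers above (not part of this commit): it assumes the module is importable as tests.unit.config.utils and that create_graphrag_config accepts a settings dict with a "models" entry, as the verb tests further down do.

# Hypothetical usage sketch; import path and test name are assumptions.
from graphrag.config.create_graphrag_config import create_graphrag_config

from tests.unit.config.utils import (
    DEFAULT_MODEL_CONFIG,
    assert_graphrag_configs,
    get_default_graphrag_config,
)


def test_defaults_round_trip():
    # build a config from only the required model settings
    actual = create_graphrag_config({"models": DEFAULT_MODEL_CONFIG})
    # compare it field by field against the default settings above
    expected = get_default_graphrag_config()
    assert_graphrag_configs(actual, expected)

assert_graphrag_configs then walks every nested section (models, vector store, reporting, output, cache, input, and the search configs) field by field.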

View File

@ -3,7 +3,7 @@
import pytest
from graphrag.utils.embeddings import create_collection_name
from graphrag.config.embeddings import create_collection_name
def test_create_collection_name():

View File

@ -7,6 +7,7 @@ from graphrag.index.workflows.compute_communities import run_workflow
from graphrag.utils.storage import load_table_from_storage
from .util import (
DEFAULT_MODEL_CONFIG,
compare_outputs,
create_test_context,
load_test_table,
@ -20,7 +21,7 @@ async def test_compute_communities():
storage=["base_relationship_edges"],
)
config = create_graphrag_config()
config = create_graphrag_config({"models": DEFAULT_MODEL_CONFIG})
await run_workflow(
config,

View File

@ -7,6 +7,7 @@ from graphrag.index.workflows.create_base_text_units import run_workflow, workfl
from graphrag.utils.storage import load_table_from_storage
from .util import (
DEFAULT_MODEL_CONFIG,
compare_outputs,
create_test_context,
load_test_table,
@ -18,7 +19,7 @@ async def test_create_base_text_units():
context = await create_test_context()
config = create_graphrag_config()
config = create_graphrag_config({"models": DEFAULT_MODEL_CONFIG})
# test data was created with 4o, so we need to match the encoding for chunks to be identical
config.chunks.encoding_model = "o200k_base"

View File

@ -10,6 +10,7 @@ from graphrag.index.workflows.create_final_communities import (
from graphrag.utils.storage import load_table_from_storage
from .util import (
DEFAULT_MODEL_CONFIG,
compare_outputs,
create_test_context,
load_test_table,
@ -27,7 +28,7 @@ async def test_create_final_communities():
],
)
config = create_graphrag_config()
config = create_graphrag_config({"models": DEFAULT_MODEL_CONFIG})
await run_workflow(
config,

View File

@ -2,8 +2,6 @@
# Licensed under the MIT License
import pytest
from graphrag.callbacks.noop_workflow_callbacks import NoopWorkflowCallbacks
from graphrag.config.create_graphrag_config import create_graphrag_config
from graphrag.config.enums import LLMType
@ -11,7 +9,6 @@ from graphrag.index.operations.summarize_communities.community_reports_extractor
CommunityReportResponse,
FindingModel,
)
from graphrag.index.run.derive_from_rows import ParallelizationError
from graphrag.index.workflows.create_final_community_reports import (
run_workflow,
workflow_name,
@ -19,6 +16,7 @@ from graphrag.index.workflows.create_final_community_reports import (
from graphrag.utils.storage import load_table_from_storage
from .util import (
DEFAULT_MODEL_CONFIG,
compare_outputs,
create_test_context,
load_test_table,
@ -41,12 +39,6 @@ MOCK_RESPONSES = [
)
]
MOCK_LLM_CONFIG = {
"type": LLMType.StaticResponse,
"responses": MOCK_RESPONSES,
"parse_json": True,
}
async def test_create_final_community_reports():
expected = load_test_table(workflow_name)
@ -61,10 +53,16 @@ async def test_create_final_community_reports():
]
)
config = create_graphrag_config()
config = create_graphrag_config({"models": DEFAULT_MODEL_CONFIG})
llm_settings = config.get_language_model_config(
config.community_reports.model_id
).model_dump()
llm_settings["type"] = LLMType.StaticResponse
llm_settings["responses"] = MOCK_RESPONSES
llm_settings["parse_json"] = True
config.community_reports.strategy = {
"type": "graph_intelligence",
"llm": MOCK_LLM_CONFIG,
"llm": llm_settings,
}
await run_workflow(
@ -83,27 +81,3 @@ async def test_create_final_community_reports():
# assert a handful of mock data items to confirm they get put in the right spot
assert actual["rank"][:1][0] == 2
assert actual["rank_explanation"][:1][0] == "<rating_explanation>"
async def test_create_final_community_reports_missing_llm_throws():
context = await create_test_context(
storage=[
"create_final_nodes",
"create_final_covariates",
"create_final_relationships",
"create_final_entities",
"create_final_communities",
]
)
config = create_graphrag_config()
config.community_reports.strategy = {
"type": "graph_intelligence",
}
with pytest.raises(ParallelizationError):
await run_workflow(
config,
context,
NoopWorkflowCallbacks(),
)
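For readability, here is the new mocking pattern from the hunk above, condensed into one hedged sketch; the mock response text and the tests.verbs.util import path are illustrative placeholders rather than part of the commit.

# Condensed sketch of the mock-LLM pattern used by the updated test.
from graphrag.config.create_graphrag_config import create_graphrag_config
from graphrag.config.enums import LLMType

from tests.verbs.util import DEFAULT_MODEL_CONFIG  # assumed helper location

config = create_graphrag_config({"models": DEFAULT_MODEL_CONFIG})
# start from the settings of the model referenced by the workflow config
llm_settings = config.get_language_model_config(
    config.community_reports.model_id
).model_dump()
# override only what the mock needs
llm_settings["type"] = LLMType.StaticResponse
llm_settings["responses"] = ["<mock community report>"]  # placeholder response
llm_settings["parse_json"] = True
config.community_reports.strategy = {
    "type": "graph_intelligence",
    "llm": llm_settings,
}

Deriving llm_settings from the configured model keeps the mock aligned with the model's defaults, so only type, responses, and parse_json are overridden.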

View File

@ -1,13 +1,11 @@
# Copyright (c) 2024 Microsoft Corporation.
# Licensed under the MIT License
import pytest
from pandas.testing import assert_series_equal
from graphrag.callbacks.noop_workflow_callbacks import NoopWorkflowCallbacks
from graphrag.config.create_graphrag_config import create_graphrag_config
from graphrag.config.enums import LLMType
from graphrag.index.run.derive_from_rows import ParallelizationError
from graphrag.index.workflows.create_final_covariates import (
run_workflow,
workflow_name,
@ -15,6 +13,7 @@ from graphrag.index.workflows.create_final_covariates import (
from graphrag.utils.storage import load_table_from_storage
from .util import (
DEFAULT_MODEL_CONFIG,
create_test_context,
load_test_table,
)
@ -25,8 +24,6 @@ MOCK_LLM_RESPONSES = [
""".strip()
]
MOCK_LLM_CONFIG = {"type": LLMType.StaticResponse, "responses": MOCK_LLM_RESPONSES}
async def test_create_final_covariates():
input = load_test_table("create_base_text_units")
@ -36,10 +33,15 @@ async def test_create_final_covariates():
storage=["create_base_text_units"],
)
config = create_graphrag_config()
config = create_graphrag_config({"models": DEFAULT_MODEL_CONFIG})
llm_settings = config.get_language_model_config(
config.claim_extraction.model_id
).model_dump()
llm_settings["type"] = LLMType.StaticResponse
llm_settings["responses"] = MOCK_LLM_RESPONSES
config.claim_extraction.strategy = {
"type": "graph_intelligence",
"llm": MOCK_LLM_CONFIG,
"llm": llm_settings,
"claim_description": "description",
}
@ -78,22 +80,3 @@ async def test_create_final_covariates():
actual["source_text"][0]
== "According to an article published on 2022/01/10, Company A was fined for bid rigging while participating in multiple public tenders published by Government Agency B."
)
async def test_create_final_covariates_missing_llm_throws():
context = await create_test_context(
storage=["create_base_text_units"],
)
config = create_graphrag_config()
config.claim_extraction.strategy = {
"type": "graph_intelligence",
"claim_description": "description",
}
with pytest.raises(ParallelizationError):
await run_workflow(
config,
context,
NoopWorkflowCallbacks(),
)

View File

@ -10,6 +10,7 @@ from graphrag.index.workflows.create_final_documents import (
from graphrag.utils.storage import load_table_from_storage
from .util import (
DEFAULT_MODEL_CONFIG,
compare_outputs,
create_test_context,
load_test_table,
@ -23,7 +24,7 @@ async def test_create_final_documents():
storage=["create_base_text_units"],
)
config = create_graphrag_config()
config = create_graphrag_config({"models": DEFAULT_MODEL_CONFIG})
await run_workflow(
config,
@ -43,7 +44,7 @@ async def test_create_final_documents_with_attribute_columns():
storage=["create_base_text_units"],
)
config = create_graphrag_config()
config = create_graphrag_config({"models": DEFAULT_MODEL_CONFIG})
config.input.document_attribute_columns = ["title"]
await run_workflow(

View File

@ -10,6 +10,7 @@ from graphrag.index.workflows.create_final_entities import (
from graphrag.utils.storage import load_table_from_storage
from .util import (
DEFAULT_MODEL_CONFIG,
compare_outputs,
create_test_context,
load_test_table,
@ -23,7 +24,7 @@ async def test_create_final_entities():
storage=["base_entity_nodes"],
)
config = create_graphrag_config()
config = create_graphrag_config({"models": DEFAULT_MODEL_CONFIG})
await run_workflow(
config,

View File

@ -10,6 +10,7 @@ from graphrag.index.workflows.create_final_nodes import (
from graphrag.utils.storage import load_table_from_storage
from .util import (
DEFAULT_MODEL_CONFIG,
compare_outputs,
create_test_context,
load_test_table,
@ -27,7 +28,7 @@ async def test_create_final_nodes():
],
)
config = create_graphrag_config()
config = create_graphrag_config({"models": DEFAULT_MODEL_CONFIG})
await run_workflow(
config,

View File

@ -11,6 +11,7 @@ from graphrag.index.workflows.create_final_relationships import (
from graphrag.utils.storage import load_table_from_storage
from .util import (
DEFAULT_MODEL_CONFIG,
compare_outputs,
create_test_context,
load_test_table,
@ -24,7 +25,7 @@ async def test_create_final_relationships():
storage=["base_relationship_edges"],
)
config = create_graphrag_config()
config = create_graphrag_config({"models": DEFAULT_MODEL_CONFIG})
await run_workflow(
config,

Some files were not shown because too many files have changed in this diff.