Mirror of https://github.com/microsoft/graphrag.git (synced 2026-01-14 00:57:23 +08:00)
Refactor config (#1593)
* Refactor config
- Add new ModelConfig to represent LLM settings
- Combines LLMParameters, ParallelizationParameters, encoding_model, and async_mode
- Add top level models config that is a list of available LLM ModelConfigs
- Remove LLMConfig inheritance and delete LLMConfig
- Replace the inheritance with a model_id reference to the ModelConfig listed in the top level models config
- Remove all fallbacks and hydration logic from create_graphrag_config
- This removes the automatic env variable overrides
- Support env variables within config files using Templating
- This requires "$" to be escaped with an extra "$", so ".*\\.txt$" becomes ".*\\.txt$$" (a sketch of the new layout appears after this list)
- Update init content to initialize new config file with the ModelConfig structure
* Use dict of ModelConfig instead of list
* Add model validations and unit tests
* Fix ruff checks
* Add semversioner change
* Fix unit tests
* validate root_dir in pydantic model
* Rename ModelConfig to LanguageModelConfig
* Rename ModelConfigMissingError to LanguageModelConfigMissingError
* Add validation for unexpected API keys
* Allow skipping pydantic validation for testing/mocking purposes.
* Add default lm configs to verb tests
* smoke test
* remove config from flows to fix llm arg mapping
* Fix embedding llm arg mapping
* Remove timestamp from smoke test outputs
* Remove unused "subworkflows" smoke test properties
* Add models to smoke test configs
* Update smoke test output path
* Send logs to logs folder
* Fix output path
* Fix csv test file pattern
* Update placeholder
* Format
* Instantiate default model configs
* Fix unit tests for config defaults
* Fix migration notebook
* Remove create_pipeline_config
* Remove several unused config models
* Remove indexing embedding and input configs
* Move embeddings function to config
* Remove skip_workflows
* Remove skip embeddings in favor of explicit naming
* fix unit test spelling mistake
* self.models[model_id] is already a language model. Remove redundant casting.
* update validation errors to instruct users to rerun graphrag init
* instantiate LanguageModelConfigs with validation
* skip validation in unit tests
* update verb tests to use default model settings instead of skipping validation
* test using llm settings
* cleanup verb tests
* remove unsafe default model config
* remove the ability to skip pydantic validation
* remove None union types when default values are set
* move vector_store from embeddings to top level of config and delete resolve_paths
* update vector store settings
* fix vector store and smoke tests
* fix serializing vector_store settings
* fix vector_store usage
* fix vector_store type
* support cli overrides for loading graphrag config
* rename storage to output
* Add --force flag to init
* Remove run_id and resume, fix Drift config assignment
* Ruff
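
As a concrete illustration of the refactor described above, here is a minimal sketch of a settings dictionary passed to create_graphrag_config under the new flat layout. The "models", "model_id", "entity_extraction", "input", and top-level "vector_store" sections come from this change set; the model id "default_chat_model", the "${GRAPHRAG_API_KEY}" templating syntax, and the specific type and model values are assumptions for illustration only.

from graphrag.config.create_graphrag_config import create_graphrag_config

# Hypothetical settings dict showing the refactored structure.
settings = {
    "models": {
        # The id is arbitrary; other sections reference it via model_id.
        "default_chat_model": {
            "type": "openai_chat",                 # assumed LLM type value
            "api_key": "${GRAPHRAG_API_KEY}",      # assumed env-var templating syntax
            "model": "gpt-4o",                     # assumed model name
        },
    },
    "entity_extraction": {
        "model_id": "default_chat_model",          # replaces the old LLMConfig inheritance
    },
    "input": {
        "file_pattern": ".*\\.txt$$",              # literal "$" doubled because of templating
    },
    "vector_store": {                              # now a top-level section
        "type": "lancedb",
    },
}

config = create_graphrag_config(values=settings, root_dir=".")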
---------
Co-authored-by: Nathan Evans <github@talkswithnumbers.com>
Co-authored-by: Alonso Guevara <alonsog@microsoft.com>
This commit is contained in:
parent 47adfe16f0
commit c644338bae
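
The CLI changes further down (the index, update, and query commands) all converge on the same pattern: build a dict of dotted-path overrides and pass it to load_config instead of mutating the config and resolving paths afterwards. A minimal sketch of that call, assuming a project rooted at ./ragtest and a custom output directory (both paths are made up for illustration):

from pathlib import Path

from graphrag.config.load_config import load_config

# Dotted keys override the matching nested values after the settings file is loaded.
cli_overrides = {
    "output.base_dir": "./custom_output",
    "reporting.base_dir": "./custom_output",
}
config = load_config(Path("./ragtest"), None, cli_overrides)  # root_dir, config_filepath, cli_overrides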
@@ -0,0 +1,4 @@
{
    "type": "minor",
    "description": "Remove config inheritance, hydration, and automatic env var overlays."
}
@@ -2,7 +2,7 @@
"cells": [
{
"cell_type": "code",
"execution_count": 1,
"execution_count": 66,
"metadata": {},
"outputs": [],
"source": [
@@ -25,7 +25,7 @@
},
{
"cell_type": "code",
"execution_count": 2,
"execution_count": 67,
"metadata": {},
"outputs": [],
"source": [
@@ -37,7 +37,7 @@
},
{
"cell_type": "code",
"execution_count": 3,
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
@@ -45,19 +45,20 @@
"\n",
"from graphrag.config.load_config import load_config\n",
"from graphrag.config.resolve_path import resolve_paths\n",
"from graphrag.index.create_pipeline_config import create_pipeline_config\n",
"from graphrag.storage.factory import create_storage\n",
"from graphrag.storage.factory import StorageFactory\n",
"\n",
"# This first block does some config loading, path resolution, and translation that is normally done by the CLI/API when running a full workflow\n",
"config = load_config(Path(PROJECT_DIRECTORY))\n",
"resolve_paths(config)\n",
"pipeline_config = create_pipeline_config(config)\n",
"storage = create_storage(pipeline_config.storage)"
"storage_config = config.storage.model_dump() # type: ignore\n",
"storage = StorageFactory().create_storage(\n",
" storage_type=storage_config[\"type\"], kwargs=storage_config\n",
")"
]
},
{
"cell_type": "code",
"execution_count": 4,
"execution_count": 69,
"metadata": {},
"outputs": [],
"source": [
@@ -68,7 +69,7 @@
},
{
"cell_type": "code",
"execution_count": 63,
"execution_count": 70,
"metadata": {},
"outputs": [],
"source": [
@@ -97,7 +98,7 @@
},
{
"cell_type": "code",
"execution_count": 64,
"execution_count": 71,
"metadata": {},
"outputs": [],
"source": [
@@ -108,22 +109,16 @@
"# First we'll go through any parquet files that had model changes and update them\n",
"# The new data model may have removed excess columns as well, but we will only make the minimal changes required for compatibility\n",
"\n",
"final_documents = await load_table_from_storage(\n",
" \"create_final_documents.parquet\", storage\n",
")\n",
"final_text_units = await load_table_from_storage(\n",
" \"create_final_text_units.parquet\", storage\n",
")\n",
"final_entities = await load_table_from_storage(\"create_final_entities.parquet\", storage)\n",
"final_nodes = await load_table_from_storage(\"create_final_nodes.parquet\", storage)\n",
"final_documents = await load_table_from_storage(\"create_final_documents\", storage)\n",
"final_text_units = await load_table_from_storage(\"create_final_text_units\", storage)\n",
"final_entities = await load_table_from_storage(\"create_final_entities\", storage)\n",
"final_nodes = await load_table_from_storage(\"create_final_nodes\", storage)\n",
"final_relationships = await load_table_from_storage(\n",
" \"create_final_relationships.parquet\", storage\n",
")\n",
"final_communities = await load_table_from_storage(\n",
" \"create_final_communities.parquet\", storage\n",
" \"create_final_relationships\", storage\n",
")\n",
"final_communities = await load_table_from_storage(\"create_final_communities\", storage)\n",
"final_community_reports = await load_table_from_storage(\n",
" \"create_final_community_reports.parquet\", storage\n",
" \"create_final_community_reports\", storage\n",
")\n",
"\n",
"\n",
@@ -183,44 +178,41 @@
" parent_df, on=\"community\", how=\"left\"\n",
" )\n",
"\n",
"await write_table_to_storage(final_documents, \"create_final_documents.parquet\", storage)\n",
"await write_table_to_storage(final_documents, \"create_final_documents\", storage)\n",
"await write_table_to_storage(final_text_units, \"create_final_text_units\", storage)\n",
"await write_table_to_storage(final_entities, \"create_final_entities\", storage)\n",
"await write_table_to_storage(final_nodes, \"create_final_nodes\", storage)\n",
"await write_table_to_storage(final_relationships, \"create_final_relationships\", storage)\n",
"await write_table_to_storage(final_communities, \"create_final_communities\", storage)\n",
"await write_table_to_storage(\n",
" final_text_units, \"create_final_text_units.parquet\", storage\n",
")\n",
"await write_table_to_storage(final_entities, \"create_final_entities.parquet\", storage)\n",
"await write_table_to_storage(final_nodes, \"create_final_nodes.parquet\", storage)\n",
"await write_table_to_storage(\n",
" final_relationships, \"create_final_relationships.parquet\", storage\n",
")\n",
"await write_table_to_storage(\n",
" final_communities, \"create_final_communities.parquet\", storage\n",
")\n",
"await write_table_to_storage(\n",
" final_community_reports, \"create_final_community_reports.parquet\", storage\n",
" final_community_reports, \"create_final_community_reports\", storage\n",
")"
]
},
{
"cell_type": "code",
"execution_count": 7,
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"from graphrag.cache.factory import create_cache\n",
"from graphrag.cache.factory import CacheFactory\n",
"from graphrag.callbacks.noop_workflow_callbacks import NoopWorkflowCallbacks\n",
"from graphrag.index.config.embeddings import get_embedded_fields, get_embedding_settings\n",
"from graphrag.index.flows.generate_text_embeddings import generate_text_embeddings\n",
"\n",
"# We only need to re-run the embeddings workflow, to ensure that embeddings for all required search fields are in place\n",
"# We'll construct the context and run this function flow directly to avoid everything else\n",
"\n",
"workflow = next(\n",
" (x for x in pipeline_config.workflows if x.name == \"generate_text_embeddings\"), None\n",
")\n",
"config = workflow.config\n",
"text_embed = config.get(\"text_embed\", {})\n",
"embedded_fields = config.get(\"embedded_fields\", {})\n",
"\n",
"embedded_fields = get_embedded_fields(config)\n",
"text_embed = get_embedding_settings(config)\n",
"callbacks = NoopWorkflowCallbacks()\n",
"cache = create_cache(pipeline_config.cache, PROJECT_DIRECTORY)\n",
"cache_config = config.cache.model_dump() # type: ignore\n",
"cache = CacheFactory().create_cache(\n",
" cache_type=cache_config[\"type\"], # type: ignore\n",
" root_dir=PROJECT_DIRECTORY,\n",
" kwargs=cache_config,\n",
")\n",
"\n",
"await generate_text_embeddings(\n",
" final_documents=None,\n",
@ -11,7 +11,7 @@ Backwards compatibility is not guaranteed at this time.
|
||||
import logging
|
||||
|
||||
from graphrag.cache.noop_pipeline_cache import NoopPipelineCache
|
||||
from graphrag.callbacks.factory import create_pipeline_reporter
|
||||
from graphrag.callbacks.reporting import create_pipeline_reporter
|
||||
from graphrag.callbacks.workflow_callbacks import WorkflowCallbacks
|
||||
from graphrag.config.enums import CacheType
|
||||
from graphrag.config.models.graph_rag_config import GraphRagConfig
|
||||
@ -24,8 +24,6 @@ log = logging.getLogger(__name__)
|
||||
|
||||
async def build_index(
|
||||
config: GraphRagConfig,
|
||||
run_id: str = "",
|
||||
is_resume_run: bool = False,
|
||||
memory_profile: bool = False,
|
||||
callbacks: list[WorkflowCallbacks] | None = None,
|
||||
progress_logger: ProgressLogger | None = None,
|
||||
@ -36,10 +34,6 @@ async def build_index(
|
||||
----------
|
||||
config : GraphRagConfig
|
||||
The configuration.
|
||||
run_id : str
|
||||
The run id. Creates a output directory with this name.
|
||||
is_resume_run : bool default=False
|
||||
Whether to resume a previous index run.
|
||||
memory_profile : bool
|
||||
Whether to enable memory profiling.
|
||||
callbacks : list[WorkflowCallbacks] | None default=None
|
||||
@ -52,11 +46,7 @@ async def build_index(
|
||||
list[PipelineRunResult]
|
||||
The list of pipeline run results
|
||||
"""
|
||||
is_update_run = bool(config.update_index_storage)
|
||||
|
||||
if is_resume_run and is_update_run:
|
||||
msg = "Cannot resume and update a run at the same time."
|
||||
raise ValueError(msg)
|
||||
is_update_run = bool(config.update_index_output)
|
||||
|
||||
pipeline_cache = (
|
||||
NoopPipelineCache() if config.cache.type == CacheType.none is None else None
|
||||
@ -78,7 +68,6 @@ async def build_index(
|
||||
cache=pipeline_cache,
|
||||
callbacks=callbacks,
|
||||
logger=progress_logger,
|
||||
run_id=run_id,
|
||||
is_update_run=is_update_run,
|
||||
):
|
||||
outputs.append(output)
|
||||
|
||||
@ -17,7 +17,7 @@ from graphrag.callbacks.noop_workflow_callbacks import NoopWorkflowCallbacks
|
||||
from graphrag.config.models.graph_rag_config import GraphRagConfig
|
||||
from graphrag.index.llm.load_llm import load_llm
|
||||
from graphrag.logger.print_progress import PrintProgressLogger
|
||||
from graphrag.prompt_tune.defaults import MAX_TOKEN_COUNT
|
||||
from graphrag.prompt_tune.defaults import MAX_TOKEN_COUNT, PROMPT_TUNING_MODEL_ID
|
||||
from graphrag.prompt_tune.generator.community_report_rating import (
|
||||
generate_community_report_rating,
|
||||
)
|
||||
@ -95,9 +95,11 @@ async def generate_indexing_prompts(
|
||||
)
|
||||
|
||||
# Create LLM from config
|
||||
# TODO: Expose way to specify Prompt Tuning model ID through config
|
||||
default_llm_settings = config.get_language_model_config(PROMPT_TUNING_MODEL_ID)
|
||||
llm = load_llm(
|
||||
"prompt_tuning",
|
||||
config.llm,
|
||||
default_llm_settings,
|
||||
cache=None,
|
||||
callbacks=NoopWorkflowCallbacks(),
|
||||
)
|
||||
@ -120,6 +122,9 @@ async def generate_indexing_prompts(
|
||||
)
|
||||
|
||||
entity_types = None
|
||||
entity_extraction_llm_settings = config.get_language_model_config(
|
||||
config.entity_extraction.model_id
|
||||
)
|
||||
if discover_entity_types:
|
||||
logger.info("Generating entity types...")
|
||||
entity_types = await generate_entity_types(
|
||||
@ -127,7 +132,7 @@ async def generate_indexing_prompts(
|
||||
domain=domain,
|
||||
persona=persona,
|
||||
docs=doc_list,
|
||||
json_mode=config.llm.model_supports_json or False,
|
||||
json_mode=entity_extraction_llm_settings.model_supports_json or False,
|
||||
)
|
||||
|
||||
logger.info("Generating entity relationship examples...")
|
||||
@ -147,7 +152,7 @@ async def generate_indexing_prompts(
|
||||
examples=examples,
|
||||
language=language,
|
||||
json_mode=False, # config.llm.model_supports_json should be used, but these prompts are used in non-json mode by the index engine
|
||||
encoding_model=config.encoding_model,
|
||||
encoding_model=entity_extraction_llm_settings.encoding_model,
|
||||
max_token_count=max_tokens,
|
||||
min_examples_required=min_examples_required,
|
||||
)
|
||||
|
||||
@ -24,12 +24,13 @@ from typing import TYPE_CHECKING, Any
|
||||
import pandas as pd
|
||||
from pydantic import validate_call
|
||||
|
||||
from graphrag.config.models.graph_rag_config import GraphRagConfig
|
||||
from graphrag.index.config.embeddings import (
|
||||
from graphrag.config.embeddings import (
|
||||
community_full_content_embedding,
|
||||
create_collection_name,
|
||||
entity_description_embedding,
|
||||
text_unit_text_embedding,
|
||||
)
|
||||
from graphrag.config.models.graph_rag_config import GraphRagConfig
|
||||
from graphrag.logger.print_progress import PrintProgressLogger
|
||||
from graphrag.query.factory import (
|
||||
get_basic_search_engine,
|
||||
@ -47,7 +48,6 @@ from graphrag.query.indexer_adapters import (
|
||||
read_indexer_text_units,
|
||||
)
|
||||
from graphrag.utils.cli import redact
|
||||
from graphrag.utils.embeddings import create_collection_name
|
||||
from graphrag.vector_stores.base import BaseVectorStore
|
||||
from graphrag.vector_stores.factory import VectorStoreFactory
|
||||
|
||||
@ -244,7 +244,7 @@ async def local_search(
|
||||
------
|
||||
TODO: Document any exceptions to expect.
|
||||
"""
|
||||
vector_store_args = config.embeddings.vector_store
|
||||
vector_store_args = config.vector_store.model_dump()
|
||||
logger.info(f"Vector Store Args: {redact(vector_store_args)}") # type: ignore # noqa
|
||||
|
||||
description_embedding_store = _get_embedding_store(
|
||||
@ -310,7 +310,7 @@ async def local_search_streaming(
|
||||
------
|
||||
TODO: Document any exceptions to expect.
|
||||
"""
|
||||
vector_store_args = config.embeddings.vector_store
|
||||
vector_store_args = config.vector_store.model_dump()
|
||||
logger.info(f"Vector Store Args: {redact(vector_store_args)}") # type: ignore # noqa
|
||||
|
||||
description_embedding_store = _get_embedding_store(
|
||||
@ -381,7 +381,7 @@ async def drift_search_streaming(
|
||||
------
|
||||
TODO: Document any exceptions to expect.
|
||||
"""
|
||||
vector_store_args = config.embeddings.vector_store
|
||||
vector_store_args = config.vector_store.model_dump()
|
||||
logger.info(f"Vector Store Args: {redact(vector_store_args)}") # type: ignore # noqa
|
||||
|
||||
description_embedding_store = _get_embedding_store(
|
||||
@ -465,7 +465,7 @@ async def drift_search(
|
||||
------
|
||||
TODO: Document any exceptions to expect.
|
||||
"""
|
||||
vector_store_args = config.embeddings.vector_store
|
||||
vector_store_args = config.vector_store.model_dump()
|
||||
logger.info(f"Vector Store Args: {redact(vector_store_args)}") # type: ignore # noqa
|
||||
|
||||
description_embedding_store = _get_embedding_store(
|
||||
@ -531,7 +531,7 @@ async def basic_search(
|
||||
------
|
||||
TODO: Document any exceptions to expect.
|
||||
"""
|
||||
vector_store_args = config.embeddings.vector_store
|
||||
vector_store_args = config.vector_store.model_dump()
|
||||
logger.info(f"Vector Store Args: {redact(vector_store_args)}") # type: ignore # noqa
|
||||
|
||||
description_embedding_store = _get_embedding_store(
|
||||
@ -576,7 +576,7 @@ async def basic_search_streaming(
|
||||
------
|
||||
TODO: Document any exceptions to expect.
|
||||
"""
|
||||
vector_store_args = config.embeddings.vector_store
|
||||
vector_store_args = config.vector_store.model_dump()
|
||||
logger.info(f"Vector Store Args: {redact(vector_store_args)}") # type: ignore # noqa
|
||||
|
||||
description_embedding_store = _get_embedding_store(
|
||||
|
||||
@ -1,45 +0,0 @@
|
||||
# Copyright (c) 2024 Microsoft Corporation.
|
||||
# Licensed under the MIT License
|
||||
|
||||
"""Create a pipeline logger."""
|
||||
|
||||
from pathlib import Path
|
||||
from typing import cast
|
||||
|
||||
from graphrag.callbacks.blob_workflow_callbacks import BlobWorkflowCallbacks
|
||||
from graphrag.callbacks.console_workflow_callbacks import ConsoleWorkflowCallbacks
|
||||
from graphrag.callbacks.file_workflow_callbacks import FileWorkflowCallbacks
|
||||
from graphrag.callbacks.workflow_callbacks import WorkflowCallbacks
|
||||
from graphrag.config.enums import ReportingType
|
||||
from graphrag.index.config.reporting import (
|
||||
PipelineBlobReportingConfig,
|
||||
PipelineFileReportingConfig,
|
||||
PipelineReportingConfig,
|
||||
)
|
||||
|
||||
|
||||
def create_pipeline_reporter(
|
||||
config: PipelineReportingConfig | None, root_dir: str | None
|
||||
) -> WorkflowCallbacks:
|
||||
"""Create a logger for the given pipeline config."""
|
||||
config = config or PipelineFileReportingConfig(base_dir="logs")
|
||||
|
||||
match config.type:
|
||||
case ReportingType.file:
|
||||
config = cast("PipelineFileReportingConfig", config)
|
||||
return FileWorkflowCallbacks(
|
||||
str(Path(root_dir or "") / (config.base_dir or ""))
|
||||
)
|
||||
case ReportingType.console:
|
||||
return ConsoleWorkflowCallbacks()
|
||||
case ReportingType.blob:
|
||||
config = cast("PipelineBlobReportingConfig", config)
|
||||
return BlobWorkflowCallbacks(
|
||||
config.connection_string,
|
||||
config.container_name,
|
||||
base_dir=config.base_dir,
|
||||
storage_account_blob_url=config.storage_account_blob_url,
|
||||
)
|
||||
case _:
|
||||
msg = f"Unknown reporting type: {config.type}"
|
||||
raise ValueError(msg)
|
||||
@ -5,12 +5,19 @@
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Generic, Literal, TypeVar
|
||||
from pathlib import Path
|
||||
from typing import TYPE_CHECKING, Generic, Literal, TypeVar, cast
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from graphrag.callbacks.blob_workflow_callbacks import BlobWorkflowCallbacks
|
||||
from graphrag.callbacks.console_workflow_callbacks import ConsoleWorkflowCallbacks
|
||||
from graphrag.callbacks.file_workflow_callbacks import FileWorkflowCallbacks
|
||||
from graphrag.config.enums import ReportingType
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from graphrag.callbacks.workflow_callbacks import WorkflowCallbacks
|
||||
|
||||
T = TypeVar("T")
|
||||
|
||||
|
||||
@ -74,3 +81,30 @@ PipelineReportingConfigTypes = (
|
||||
| PipelineConsoleReportingConfig
|
||||
| PipelineBlobReportingConfig
|
||||
)
|
||||
|
||||
|
||||
def create_pipeline_reporter(
|
||||
config: PipelineReportingConfig | None, root_dir: str | None
|
||||
) -> WorkflowCallbacks:
|
||||
"""Create a logger for the given pipeline config."""
|
||||
config = config or PipelineFileReportingConfig(base_dir="logs")
|
||||
|
||||
match config.type:
|
||||
case ReportingType.file:
|
||||
config = cast("PipelineFileReportingConfig", config)
|
||||
return FileWorkflowCallbacks(
|
||||
str(Path(root_dir or "") / (config.base_dir or ""))
|
||||
)
|
||||
case ReportingType.console:
|
||||
return ConsoleWorkflowCallbacks()
|
||||
case ReportingType.blob:
|
||||
config = cast("PipelineBlobReportingConfig", config)
|
||||
return BlobWorkflowCallbacks(
|
||||
config.connection_string,
|
||||
config.container_name,
|
||||
base_dir=config.base_dir,
|
||||
storage_account_blob_url=config.storage_account_blob_url,
|
||||
)
|
||||
case _:
|
||||
msg = f"Unknown reporting type: {config.type}"
|
||||
raise ValueError(msg)
|
||||
@ -6,7 +6,6 @@
|
||||
import asyncio
|
||||
import logging
|
||||
import sys
|
||||
import time
|
||||
import warnings
|
||||
from pathlib import Path
|
||||
|
||||
@ -14,7 +13,6 @@ import graphrag.api as api
|
||||
from graphrag.config.enums import CacheType
|
||||
from graphrag.config.load_config import load_config
|
||||
from graphrag.config.logging import enable_logging_with_config
|
||||
from graphrag.config.resolve_path import resolve_paths
|
||||
from graphrag.index.validate_config import validate_config_names
|
||||
from graphrag.logger.base import ProgressLogger
|
||||
from graphrag.logger.factory import LoggerFactory, LoggerType
|
||||
@ -66,7 +64,6 @@ def _register_signal_handlers(logger: ProgressLogger):
|
||||
def index_cli(
|
||||
root_dir: Path,
|
||||
verbose: bool,
|
||||
resume: str | None,
|
||||
memprofile: bool,
|
||||
cache: bool,
|
||||
logger: LoggerType,
|
||||
@ -76,18 +73,20 @@ def index_cli(
|
||||
output_dir: Path | None,
|
||||
):
|
||||
"""Run the pipeline with the given config."""
|
||||
config = load_config(root_dir, config_filepath)
|
||||
cli_overrides = {}
|
||||
if output_dir:
|
||||
cli_overrides["output.base_dir"] = str(output_dir)
|
||||
cli_overrides["reporting.base_dir"] = str(output_dir)
|
||||
config = load_config(root_dir, config_filepath, cli_overrides)
|
||||
|
||||
_run_index(
|
||||
config=config,
|
||||
verbose=verbose,
|
||||
resume=resume,
|
||||
memprofile=memprofile,
|
||||
cache=cache,
|
||||
logger=logger,
|
||||
dry_run=dry_run,
|
||||
skip_validation=skip_validation,
|
||||
output_dir=output_dir,
|
||||
)
|
||||
|
||||
|
||||
@ -102,51 +101,44 @@ def update_cli(
|
||||
output_dir: Path | None,
|
||||
):
|
||||
"""Run the pipeline with the given config."""
|
||||
config = load_config(root_dir, config_filepath)
|
||||
cli_overrides = {}
|
||||
if output_dir:
|
||||
cli_overrides["output.base_dir"] = str(output_dir)
|
||||
cli_overrides["reporting.base_dir"] = str(output_dir)
|
||||
config = load_config(root_dir, config_filepath, cli_overrides)
|
||||
|
||||
# Check if update storage exist, if not configure it with default values
|
||||
if not config.update_index_storage:
|
||||
from graphrag.config.defaults import STORAGE_TYPE, UPDATE_STORAGE_BASE_DIR
|
||||
from graphrag.config.models.storage_config import StorageConfig
|
||||
# Check if update output exist, if not configure it with default values
|
||||
if not config.update_index_output:
|
||||
from graphrag.config.defaults import OUTPUT_TYPE, UPDATE_OUTPUT_BASE_DIR
|
||||
from graphrag.config.models.output_config import OutputConfig
|
||||
|
||||
config.update_index_storage = StorageConfig(
|
||||
type=STORAGE_TYPE,
|
||||
base_dir=UPDATE_STORAGE_BASE_DIR,
|
||||
config.update_index_output = OutputConfig(
|
||||
type=OUTPUT_TYPE,
|
||||
base_dir=UPDATE_OUTPUT_BASE_DIR,
|
||||
)
|
||||
|
||||
_run_index(
|
||||
config=config,
|
||||
verbose=verbose,
|
||||
resume=False,
|
||||
memprofile=memprofile,
|
||||
cache=cache,
|
||||
logger=logger,
|
||||
dry_run=False,
|
||||
skip_validation=skip_validation,
|
||||
output_dir=output_dir,
|
||||
)
|
||||
|
||||
|
||||
def _run_index(
|
||||
config,
|
||||
verbose,
|
||||
resume,
|
||||
memprofile,
|
||||
cache,
|
||||
logger,
|
||||
dry_run,
|
||||
skip_validation,
|
||||
output_dir,
|
||||
):
|
||||
progress_logger = LoggerFactory().create_logger(logger)
|
||||
info, error, success = _logger(progress_logger)
|
||||
run_id = resume or time.strftime("%Y%m%d-%H%M%S")
|
||||
|
||||
config.storage.base_dir = str(output_dir) if output_dir else config.storage.base_dir
|
||||
config.reporting.base_dir = (
|
||||
str(output_dir) if output_dir else config.reporting.base_dir
|
||||
)
|
||||
resolve_paths(config, run_id)
|
||||
|
||||
if not cache:
|
||||
config.cache.type = CacheType.none
|
||||
@ -163,7 +155,7 @@ def _run_index(
|
||||
if skip_validation:
|
||||
validate_config_names(progress_logger, config)
|
||||
|
||||
info(f"Starting pipeline run for: {run_id}, {dry_run=}", verbose)
|
||||
info(f"Starting pipeline run. {dry_run=}", verbose)
|
||||
info(
|
||||
f"Using default configuration: {redact(config.model_dump())}",
|
||||
verbose,
|
||||
@ -178,8 +170,6 @@ def _run_index(
|
||||
outputs = asyncio.run(
|
||||
api.build_index(
|
||||
config=config,
|
||||
run_id=run_id,
|
||||
is_resume_run=bool(resume),
|
||||
memory_profile=memprofile,
|
||||
progress_logger=progress_logger,
|
||||
)
|
||||
|
||||
@ -29,8 +29,22 @@ from graphrag.prompts.query.local_search_system_prompt import LOCAL_SEARCH_SYSTE
|
||||
from graphrag.prompts.query.question_gen_system_prompt import QUESTION_SYSTEM_PROMPT
|
||||
|
||||
|
||||
def initialize_project_at(path: Path) -> None:
|
||||
"""Initialize the project at the given path."""
|
||||
def initialize_project_at(path: Path, force: bool) -> None:
|
||||
"""
|
||||
Initialize the project at the given path.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
path : Path
|
||||
The path at which to initialize the project.
|
||||
force : bool
|
||||
Whether to force initialization even if the project already exists.
|
||||
|
||||
Raises
|
||||
------
|
||||
ValueError
|
||||
If the project already exists and force is False.
|
||||
"""
|
||||
progress_logger = LoggerFactory().create_logger(LoggerType.RICH)
|
||||
progress_logger.info(f"Initializing project at {path}") # noqa: G004
|
||||
root = Path(path)
|
||||
@ -38,7 +52,7 @@ def initialize_project_at(path: Path) -> None:
|
||||
root.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
settings_yaml = root / "settings.yaml"
|
||||
if settings_yaml.exists():
|
||||
if settings_yaml.exists() and not force:
|
||||
msg = f"Project already initialized at {root}"
|
||||
raise ValueError(msg)
|
||||
|
||||
@ -46,7 +60,7 @@ def initialize_project_at(path: Path) -> None:
|
||||
file.write(INIT_YAML.encode(encoding="utf-8", errors="strict"))
|
||||
|
||||
dotenv = root / ".env"
|
||||
if not dotenv.exists():
|
||||
if not dotenv.exists() or force:
|
||||
with dotenv.open("wb") as file:
|
||||
file.write(INIT_DOTENV.encode(encoding="utf-8", errors="strict"))
|
||||
|
||||
@ -71,6 +85,6 @@ def initialize_project_at(path: Path) -> None:
|
||||
|
||||
for name, content in prompts.items():
|
||||
prompt_file = prompts_dir / f"{name}.txt"
|
||||
if not prompt_file.exists():
|
||||
if not prompt_file.exists() or force:
|
||||
with prompt_file.open("wb") as file:
|
||||
file.write(content.encode(encoding="utf-8", errors="strict"))
|
||||
|
||||
@ -109,11 +109,15 @@ def _initialize_cli(
|
||||
),
|
||||
),
|
||||
],
|
||||
force: Annotated[
|
||||
bool,
|
||||
typer.Option(help="Force initialization even if the project already exists."),
|
||||
] = False,
|
||||
):
|
||||
"""Generate a default configuration file."""
|
||||
from graphrag.cli.initialize import initialize_project_at
|
||||
|
||||
initialize_project_at(path=root)
|
||||
initialize_project_at(path=root, force=force)
|
||||
|
||||
|
||||
@app.command("index")
|
||||
@ -143,9 +147,6 @@ def _index_cli(
|
||||
memprofile: Annotated[
|
||||
bool, typer.Option(help="Run the indexing pipeline with memory profiling")
|
||||
] = False,
|
||||
resume: Annotated[
|
||||
str | None, typer.Option(help="Resume a given indexing run")
|
||||
] = None,
|
||||
logger: Annotated[
|
||||
LoggerType, typer.Option(help="The progress logger to use.")
|
||||
] = LoggerType.RICH,
|
||||
@ -165,7 +166,7 @@ def _index_cli(
|
||||
output: Annotated[
|
||||
Path | None,
|
||||
typer.Option(
|
||||
help="Indexing pipeline output directory. Overrides storage.base_dir in the configuration file.",
|
||||
help="Indexing pipeline output directory. Overrides output.base_dir in the configuration file.",
|
||||
dir_okay=True,
|
||||
writable=True,
|
||||
resolve_path=True,
|
||||
@ -178,7 +179,6 @@ def _index_cli(
|
||||
index_cli(
|
||||
root_dir=root,
|
||||
verbose=verbose,
|
||||
resume=resume,
|
||||
memprofile=memprofile,
|
||||
cache=cache,
|
||||
logger=LoggerType(logger),
|
||||
@ -226,7 +226,7 @@ def _update_cli(
|
||||
output: Annotated[
|
||||
Path | None,
|
||||
typer.Option(
|
||||
help="Indexing pipeline output directory. Overrides storage.base_dir in the configuration file.",
|
||||
help="Indexing pipeline output directory. Overrides output.base_dir in the configuration file.",
|
||||
dir_okay=True,
|
||||
writable=True,
|
||||
resolve_path=True,
|
||||
@ -236,7 +236,7 @@ def _update_cli(
|
||||
"""
|
||||
Update an existing knowledge graph index.
|
||||
|
||||
Applies a default storage configuration (if not provided by config), saving the new index to the local file system in the `update_output` folder.
|
||||
Applies a default output configuration (if not provided by config), saving the new index to the local file system in the `update_output` folder.
|
||||
"""
|
||||
from graphrag.cli.index import update_cli
|
||||
|
||||
|
||||
@ -12,8 +12,6 @@ import pandas as pd
|
||||
import graphrag.api as api
|
||||
from graphrag.config.load_config import load_config
|
||||
from graphrag.config.models.graph_rag_config import GraphRagConfig
|
||||
from graphrag.config.resolve_path import resolve_paths
|
||||
from graphrag.index.create_pipeline_config import create_pipeline_config
|
||||
from graphrag.logger.print_progress import PrintProgressLogger
|
||||
from graphrag.storage.factory import StorageFactory
|
||||
from graphrag.utils.storage import load_table_from_storage, storage_has_table
|
||||
@ -36,9 +34,10 @@ def run_global_search(
|
||||
Loads index files required for global search and calls the Query API.
|
||||
"""
|
||||
root = root_dir.resolve()
|
||||
config = load_config(root, config_filepath)
|
||||
config.storage.base_dir = str(data_dir) if data_dir else config.storage.base_dir
|
||||
resolve_paths(config)
|
||||
cli_overrides = {}
|
||||
if data_dir:
|
||||
cli_overrides["output.base_dir"] = str(data_dir)
|
||||
config = load_config(root, config_filepath, cli_overrides)
|
||||
|
||||
dataframe_dict = _resolve_output_files(
|
||||
config=config,
|
||||
@ -120,9 +119,10 @@ def run_local_search(
|
||||
Loads index files required for local search and calls the Query API.
|
||||
"""
|
||||
root = root_dir.resolve()
|
||||
config = load_config(root, config_filepath)
|
||||
config.storage.base_dir = str(data_dir) if data_dir else config.storage.base_dir
|
||||
resolve_paths(config)
|
||||
cli_overrides = {}
|
||||
if data_dir:
|
||||
cli_overrides["output.base_dir"] = str(data_dir)
|
||||
config = load_config(root, config_filepath, cli_overrides)
|
||||
|
||||
dataframe_dict = _resolve_output_files(
|
||||
config=config,
|
||||
@ -211,9 +211,10 @@ def run_drift_search(
|
||||
Loads index files required for local search and calls the Query API.
|
||||
"""
|
||||
root = root_dir.resolve()
|
||||
config = load_config(root, config_filepath)
|
||||
config.storage.base_dir = str(data_dir) if data_dir else config.storage.base_dir
|
||||
resolve_paths(config)
|
||||
cli_overrides = {}
|
||||
if data_dir:
|
||||
cli_overrides["output.base_dir"] = str(data_dir)
|
||||
config = load_config(root, config_filepath, cli_overrides)
|
||||
|
||||
dataframe_dict = _resolve_output_files(
|
||||
config=config,
|
||||
@ -296,9 +297,10 @@ def run_basic_search(
|
||||
Loads index files required for basic search and calls the Query API.
|
||||
"""
|
||||
root = root_dir.resolve()
|
||||
config = load_config(root, config_filepath)
|
||||
config.storage.base_dir = str(data_dir) if data_dir else config.storage.base_dir
|
||||
resolve_paths(config)
|
||||
cli_overrides = {}
|
||||
if data_dir:
|
||||
cli_overrides["output.base_dir"] = str(data_dir)
|
||||
config = load_config(root, config_filepath, cli_overrides)
|
||||
|
||||
dataframe_dict = _resolve_output_files(
|
||||
config=config,
|
||||
@ -352,10 +354,9 @@ def _resolve_output_files(
|
||||
) -> dict[str, pd.DataFrame]:
|
||||
"""Read indexing output files to a dataframe dict."""
|
||||
dataframe_dict = {}
|
||||
pipeline_config = create_pipeline_config(config)
|
||||
storage_config = pipeline_config.storage.model_dump() # type: ignore
|
||||
output_config = config.output.model_dump() # type: ignore
|
||||
storage_obj = StorageFactory().create_storage(
|
||||
storage_type=storage_config["type"], kwargs=storage_config
|
||||
storage_type=output_config["type"], kwargs=output_config
|
||||
)
|
||||
for name in output_list:
|
||||
df_value = asyncio.run(load_table_from_storage(name=name, storage=storage_obj))
|
||||
|
||||
@ -1,179 +0,0 @@
|
||||
# Copyright (c) 2024 Microsoft Corporation.
|
||||
# Licensed under the MIT License
|
||||
|
||||
"""Load a GraphRagConfiguration from a file."""
|
||||
|
||||
import json
|
||||
from abc import ABC, abstractmethod
|
||||
from pathlib import Path
|
||||
|
||||
import yaml
|
||||
|
||||
from graphrag.config.create_graphrag_config import create_graphrag_config
|
||||
from graphrag.config.models.graph_rag_config import GraphRagConfig
|
||||
|
||||
_default_config_files = ["settings.yaml", "settings.yml", "settings.json"]
|
||||
|
||||
|
||||
def search_for_config_in_root_dir(root: str | Path) -> Path | None:
|
||||
"""Resolve the config path from the given root directory.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
root : str | Path
|
||||
The path to the root directory containing the config file.
|
||||
Searches for a default config file (settings.{yaml,yml,json}).
|
||||
|
||||
Returns
|
||||
-------
|
||||
Path | None
|
||||
returns a Path if there is a config in the root directory
|
||||
Otherwise returns None.
|
||||
"""
|
||||
root = Path(root)
|
||||
|
||||
if not root.is_dir():
|
||||
msg = f"Invalid config path: {root} is not a directory"
|
||||
raise FileNotFoundError(msg)
|
||||
|
||||
for file in _default_config_files:
|
||||
if (root / file).is_file():
|
||||
return root / file
|
||||
|
||||
return None
|
||||
|
||||
|
||||
class ConfigFileLoader(ABC):
|
||||
"""Base class for loading a configuration from a file."""
|
||||
|
||||
@abstractmethod
|
||||
def load_config(self, config_path: str | Path) -> GraphRagConfig:
|
||||
"""Load configuration from a file."""
|
||||
raise NotImplementedError
|
||||
|
||||
|
||||
class ConfigYamlLoader(ConfigFileLoader):
|
||||
"""Load a configuration from a yaml file."""
|
||||
|
||||
def load_config(self, config_path: str | Path) -> GraphRagConfig:
|
||||
"""Load a configuration from a yaml file.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
config_path : str | Path
|
||||
The path to the yaml file to load.
|
||||
|
||||
Returns
|
||||
-------
|
||||
GraphRagConfig
|
||||
The loaded configuration.
|
||||
|
||||
Raises
|
||||
------
|
||||
ValueError
|
||||
If the file extension is not .yaml or .yml.
|
||||
FileNotFoundError
|
||||
If the config file is not found.
|
||||
"""
|
||||
config_path = Path(config_path)
|
||||
if config_path.suffix not in [".yaml", ".yml"]:
|
||||
msg = f"Invalid file extension for loading yaml config from: {config_path!s}. Expected .yaml or .yml"
|
||||
raise ValueError(msg)
|
||||
root_dir = str(config_path.parent)
|
||||
if not config_path.is_file():
|
||||
msg = f"Config file not found: {config_path}"
|
||||
raise FileNotFoundError(msg)
|
||||
with config_path.open("rb") as file:
|
||||
data = yaml.safe_load(file.read().decode(encoding="utf-8", errors="strict"))
|
||||
return create_graphrag_config(data, root_dir)
|
||||
|
||||
|
||||
class ConfigJsonLoader(ConfigFileLoader):
|
||||
"""Load a configuration from a json file."""
|
||||
|
||||
def load_config(self, config_path: str | Path) -> GraphRagConfig:
|
||||
"""Load a configuration from a json file.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
config_path : str | Path
|
||||
The path to the json file to load.
|
||||
|
||||
Returns
|
||||
-------
|
||||
GraphRagConfig
|
||||
The loaded configuration.
|
||||
|
||||
Raises
|
||||
------
|
||||
ValueError
|
||||
If the file extension is not .json.
|
||||
FileNotFoundError
|
||||
If the config file is not found.
|
||||
"""
|
||||
config_path = Path(config_path)
|
||||
root_dir = str(config_path.parent)
|
||||
if config_path.suffix != ".json":
|
||||
msg = f"Invalid file extension for loading json config from: {config_path!s}. Expected .json"
|
||||
raise ValueError(msg)
|
||||
if not config_path.is_file():
|
||||
msg = f"Config file not found: {config_path}"
|
||||
raise FileNotFoundError(msg)
|
||||
with config_path.open("rb") as file:
|
||||
data = json.loads(file.read().decode(encoding="utf-8", errors="strict"))
|
||||
return create_graphrag_config(data, root_dir)
|
||||
|
||||
|
||||
def get_config_file_loader(config_path: str | Path) -> ConfigFileLoader:
|
||||
"""Config File Loader Factory.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
config_path : str | Path
|
||||
The path to the config file.
|
||||
|
||||
Returns
|
||||
-------
|
||||
ConfigFileLoader
|
||||
The config file loader for the provided config file.
|
||||
|
||||
Raises
|
||||
------
|
||||
ValueError
|
||||
If the config file extension is not supported.
|
||||
"""
|
||||
config_path = Path(config_path)
|
||||
ext = config_path.suffix
|
||||
match ext:
|
||||
case ".yaml" | ".yml":
|
||||
return ConfigYamlLoader()
|
||||
case ".json":
|
||||
return ConfigJsonLoader()
|
||||
case _:
|
||||
msg = f"Unsupported config file extension: {ext}"
|
||||
raise ValueError(msg)
|
||||
|
||||
|
||||
def load_config_from_file(config_path: str | Path) -> GraphRagConfig:
|
||||
"""Load a configuration from a file.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
config_path : str | Path
|
||||
The path to the configuration file.
|
||||
Supports .yaml, .yml, and .json config files.
|
||||
|
||||
Returns
|
||||
-------
|
||||
GraphRagConfig
|
||||
The loaded configuration.
|
||||
|
||||
Raises
|
||||
------
|
||||
ValueError
|
||||
If the file extension is not supported.
|
||||
FileNotFoundError
|
||||
If the config file is not found.
|
||||
"""
|
||||
loader = get_config_file_loader(config_path)
|
||||
return loader.load_config(config_path)
|
||||
@ -3,777 +3,41 @@
|
||||
|
||||
"""Parameterization settings for the default configuration, loaded from environment variables."""
|
||||
|
||||
import os
|
||||
from enum import Enum
|
||||
from pathlib import Path
|
||||
from typing import Any, cast
|
||||
from typing import Any
|
||||
|
||||
from environs import Env
|
||||
|
||||
import graphrag.config.defaults as defs
|
||||
from graphrag.config.enums import (
|
||||
AsyncType,
|
||||
CacheType,
|
||||
InputFileType,
|
||||
InputType,
|
||||
LLMType,
|
||||
ReportingType,
|
||||
StorageType,
|
||||
TextEmbeddingTarget,
|
||||
)
|
||||
from graphrag.config.environment_reader import EnvironmentReader
|
||||
from graphrag.config.errors import (
|
||||
ApiKeyMissingError,
|
||||
AzureApiBaseMissingError,
|
||||
AzureDeploymentNameMissingError,
|
||||
)
|
||||
from graphrag.config.models.basic_search_config import BasicSearchConfig
|
||||
from graphrag.config.models.cache_config import CacheConfig
|
||||
from graphrag.config.models.chunking_config import ChunkingConfig, ChunkStrategyType
|
||||
from graphrag.config.models.claim_extraction_config import ClaimExtractionConfig
|
||||
from graphrag.config.models.cluster_graph_config import ClusterGraphConfig
|
||||
from graphrag.config.models.community_reports_config import CommunityReportsConfig
|
||||
from graphrag.config.models.drift_search_config import DRIFTSearchConfig
|
||||
from graphrag.config.models.embed_graph_config import EmbedGraphConfig
|
||||
from graphrag.config.models.entity_extraction_config import EntityExtractionConfig
|
||||
from graphrag.config.models.global_search_config import GlobalSearchConfig
|
||||
from graphrag.config.models.graph_rag_config import GraphRagConfig
|
||||
from graphrag.config.models.input_config import InputConfig
|
||||
from graphrag.config.models.llm_parameters import LLMParameters
|
||||
from graphrag.config.models.local_search_config import LocalSearchConfig
|
||||
from graphrag.config.models.parallelization_parameters import ParallelizationParameters
|
||||
from graphrag.config.models.reporting_config import ReportingConfig
|
||||
from graphrag.config.models.snapshots_config import SnapshotsConfig
|
||||
from graphrag.config.models.storage_config import StorageConfig
|
||||
from graphrag.config.models.summarize_descriptions_config import (
|
||||
SummarizeDescriptionsConfig,
|
||||
)
|
||||
from graphrag.config.models.text_embedding_config import TextEmbeddingConfig
|
||||
from graphrag.config.models.umap_config import UmapConfig
|
||||
from graphrag.config.read_dotenv import read_dotenv
|
||||
|
||||
|
||||
def create_graphrag_config(
|
||||
values: dict[str, Any] | None = None, root_dir: str | None = None
|
||||
values: dict[str, Any] | None = None,
|
||||
root_dir: str | None = None,
|
||||
) -> GraphRagConfig:
|
||||
"""Load Configuration Parameters from a dictionary."""
|
||||
"""Load Configuration Parameters from a dictionary.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
values : dict[str, Any] | None
|
||||
Dictionary of configuration values to pass into pydantic model.
|
||||
root_dir : str | None
|
||||
Root directory for the project.
|
||||
skip_validation : bool
|
||||
Skip pydantic model validation of the configuration.
|
||||
This is useful for testing and mocking purposes but
|
||||
should not be used in the core code or API.
|
||||
|
||||
Returns
|
||||
-------
|
||||
GraphRagConfig
|
||||
The configuration object.
|
||||
|
||||
Raises
|
||||
------
|
||||
ValidationError
|
||||
If the configuration values do not satisfy pydantic validation.
|
||||
"""
|
||||
values = values or {}
|
||||
root_dir = root_dir or str(Path.cwd())
|
||||
env = _make_env(root_dir)
|
||||
_token_replace(cast("dict", values))
|
||||
|
||||
reader = EnvironmentReader(env)
|
||||
|
||||
def hydrate_async_type(input: dict[str, Any], base: AsyncType) -> AsyncType:
|
||||
value = input.get(Fragment.async_mode)
|
||||
return AsyncType(value) if value else base
|
||||
|
||||
def hydrate_llm_params(
|
||||
config: dict[str, Any], base: LLMParameters
|
||||
) -> LLMParameters:
|
||||
with reader.use(config.get("llm")):
|
||||
llm_type = reader.str(Fragment.type)
|
||||
llm_type = LLMType(llm_type) if llm_type else base.type
|
||||
api_key = reader.str(Fragment.api_key) or base.api_key
|
||||
api_base = reader.str(Fragment.api_base) or base.api_base
|
||||
audience = reader.str(Fragment.audience) or base.audience
|
||||
deployment_name = (
|
||||
reader.str(Fragment.deployment_name) or base.deployment_name
|
||||
)
|
||||
encoding_model = reader.str(Fragment.encoding_model) or base.encoding_model
|
||||
|
||||
if api_key is None and not _is_azure(llm_type):
|
||||
raise ApiKeyMissingError
|
||||
if _is_azure(llm_type):
|
||||
if api_base is None:
|
||||
raise AzureApiBaseMissingError
|
||||
if deployment_name is None:
|
||||
raise AzureDeploymentNameMissingError
|
||||
|
||||
sleep_on_rate_limit = reader.bool(Fragment.sleep_recommendation)
|
||||
if sleep_on_rate_limit is None:
|
||||
sleep_on_rate_limit = base.sleep_on_rate_limit_recommendation
|
||||
|
||||
return LLMParameters(
|
||||
api_key=api_key,
|
||||
type=llm_type,
|
||||
api_base=api_base,
|
||||
api_version=reader.str(Fragment.api_version) or base.api_version,
|
||||
organization=reader.str("organization") or base.organization,
|
||||
proxy=reader.str("proxy") or base.proxy,
|
||||
model=reader.str("model") or base.model,
|
||||
encoding_model=encoding_model,
|
||||
max_tokens=reader.int(Fragment.max_tokens) or base.max_tokens,
|
||||
temperature=reader.float(Fragment.temperature) or base.temperature,
|
||||
top_p=reader.float(Fragment.top_p) or base.top_p,
|
||||
n=reader.int(Fragment.n) or base.n,
|
||||
model_supports_json=reader.bool(Fragment.model_supports_json)
|
||||
or base.model_supports_json,
|
||||
request_timeout=reader.float(Fragment.request_timeout)
|
||||
or base.request_timeout,
|
||||
audience=audience,
|
||||
deployment_name=deployment_name,
|
||||
tokens_per_minute=reader.int("tokens_per_minute", Fragment.tpm)
|
||||
or base.tokens_per_minute,
|
||||
requests_per_minute=reader.int("requests_per_minute", Fragment.rpm)
|
||||
or base.requests_per_minute,
|
||||
max_retries=reader.int(Fragment.max_retries) or base.max_retries,
|
||||
max_retry_wait=reader.float(Fragment.max_retry_wait)
|
||||
or base.max_retry_wait,
|
||||
sleep_on_rate_limit_recommendation=sleep_on_rate_limit,
|
||||
concurrent_requests=reader.int(Fragment.concurrent_requests)
|
||||
or base.concurrent_requests,
|
||||
)
|
||||
|
||||
def hydrate_embeddings_params(
|
||||
config: dict[str, Any], base: LLMParameters
|
||||
) -> LLMParameters:
|
||||
with reader.use(config.get("llm")):
|
||||
api_type = reader.str(Fragment.type) or defs.EMBEDDING_TYPE
|
||||
api_type = LLMType(api_type) if api_type else defs.LLM_TYPE
|
||||
api_key = reader.str(Fragment.api_key) or base.api_key
|
||||
|
||||
# Account for various permutations of config settings such as:
|
||||
# - same api_bases for LLM and embeddings (both Azure)
|
||||
# - different api_bases for LLM and embeddings (both Azure)
|
||||
# - LLM uses Azure OpenAI, while embeddings uses base OpenAI (this one is important)
|
||||
# - LLM uses Azure OpenAI, while embeddings uses third-party OpenAI-like API
|
||||
api_base = (
|
||||
reader.str(Fragment.api_base) or base.api_base
|
||||
if _is_azure(api_type)
|
||||
else reader.str(Fragment.api_base)
|
||||
)
|
||||
api_version = (
|
||||
reader.str(Fragment.api_version) or base.api_version
|
||||
if _is_azure(api_type)
|
||||
else reader.str(Fragment.api_version)
|
||||
)
|
||||
api_organization = reader.str("organization") or base.organization
|
||||
api_proxy = reader.str("proxy") or base.proxy
|
||||
audience = reader.str(Fragment.audience) or base.audience
|
||||
deployment_name = reader.str(Fragment.deployment_name)
|
||||
encoding_model = reader.str(Fragment.encoding_model) or base.encoding_model
|
||||
|
||||
if api_key is None and not _is_azure(api_type):
|
||||
raise ApiKeyMissingError(embedding=True)
|
||||
if _is_azure(api_type):
|
||||
if api_base is None:
|
||||
raise AzureApiBaseMissingError(embedding=True)
|
||||
if deployment_name is None:
|
||||
raise AzureDeploymentNameMissingError(embedding=True)
|
||||
|
||||
sleep_on_rate_limit = reader.bool(Fragment.sleep_recommendation)
|
||||
if sleep_on_rate_limit is None:
|
||||
sleep_on_rate_limit = base.sleep_on_rate_limit_recommendation
|
||||
|
||||
return LLMParameters(
|
||||
api_key=api_key,
|
||||
type=api_type,
|
||||
api_base=api_base,
|
||||
api_version=api_version,
|
||||
organization=api_organization,
|
||||
proxy=api_proxy,
|
||||
model=reader.str(Fragment.model) or defs.EMBEDDING_MODEL,
|
||||
encoding_model=encoding_model,
|
||||
request_timeout=reader.float(Fragment.request_timeout)
|
||||
or defs.LLM_REQUEST_TIMEOUT,
|
||||
audience=audience,
|
||||
deployment_name=deployment_name,
|
||||
tokens_per_minute=reader.int("tokens_per_minute", Fragment.tpm)
|
||||
or defs.LLM_TOKENS_PER_MINUTE,
|
||||
requests_per_minute=reader.int("requests_per_minute", Fragment.rpm)
|
||||
or defs.LLM_REQUESTS_PER_MINUTE,
|
||||
max_retries=reader.int(Fragment.max_retries) or defs.LLM_MAX_RETRIES,
|
||||
max_retry_wait=reader.float(Fragment.max_retry_wait)
|
||||
or defs.LLM_MAX_RETRY_WAIT,
|
||||
sleep_on_rate_limit_recommendation=sleep_on_rate_limit,
|
||||
concurrent_requests=reader.int(Fragment.concurrent_requests)
|
||||
or defs.LLM_CONCURRENT_REQUESTS,
|
||||
)
|
||||
|
||||
def hydrate_parallelization_params(
|
||||
config: dict[str, Any], base: ParallelizationParameters
|
||||
) -> ParallelizationParameters:
|
||||
with reader.use(config.get("parallelization")):
|
||||
return ParallelizationParameters(
|
||||
num_threads=reader.int("num_threads", Fragment.thread_count)
|
||||
or base.num_threads,
|
||||
stagger=reader.float("stagger", Fragment.thread_stagger)
|
||||
or base.stagger,
|
||||
)
|
||||
|
||||
fallback_oai_key = env("OPENAI_API_KEY", env("AZURE_OPENAI_API_KEY", None))
|
||||
fallback_oai_org = env("OPENAI_ORG_ID", None)
|
||||
fallback_oai_base = env("OPENAI_BASE_URL", None)
|
||||
fallback_oai_version = env("OPENAI_API_VERSION", None)
|
||||
|
||||
with reader.envvar_prefix(Section.graphrag), reader.use(values):
|
||||
async_mode = reader.str(Fragment.async_mode)
|
||||
async_mode = AsyncType(async_mode) if async_mode else defs.ASYNC_MODE
|
||||
|
||||
fallback_oai_key = reader.str(Fragment.api_key) or fallback_oai_key
|
||||
fallback_oai_org = reader.str(Fragment.api_organization) or fallback_oai_org
|
||||
fallback_oai_base = reader.str(Fragment.api_base) or fallback_oai_base
|
||||
fallback_oai_version = reader.str(Fragment.api_version) or fallback_oai_version
|
||||
fallback_oai_proxy = reader.str(Fragment.api_proxy)
|
||||
global_encoding_model = (
|
||||
reader.str(Fragment.encoding_model) or defs.ENCODING_MODEL
|
||||
)
|
||||
|
||||
with reader.envvar_prefix(Section.llm):
|
||||
with reader.use(values.get("llm")):
|
||||
llm_type = reader.str(Fragment.type)
|
||||
llm_type = LLMType(llm_type) if llm_type else defs.LLM_TYPE
|
||||
api_key = reader.str(Fragment.api_key) or fallback_oai_key
|
||||
api_organization = (
|
||||
reader.str(Fragment.api_organization) or fallback_oai_org
|
||||
)
|
||||
api_base = reader.str(Fragment.api_base) or fallback_oai_base
|
||||
api_version = reader.str(Fragment.api_version) or fallback_oai_version
|
||||
api_proxy = reader.str(Fragment.api_proxy) or fallback_oai_proxy
|
||||
audience = reader.str(Fragment.audience)
|
||||
deployment_name = reader.str(Fragment.deployment_name)
|
||||
encoding_model = (
|
||||
reader.str(Fragment.encoding_model) or global_encoding_model
|
||||
)
|
||||
|
||||
if api_key is None and not _is_azure(llm_type):
|
||||
raise ApiKeyMissingError
|
||||
if _is_azure(llm_type):
|
||||
if api_base is None:
|
||||
raise AzureApiBaseMissingError
|
||||
if deployment_name is None:
|
||||
raise AzureDeploymentNameMissingError
|
||||
|
||||
sleep_on_rate_limit = reader.bool(Fragment.sleep_recommendation)
|
||||
if sleep_on_rate_limit is None:
|
||||
sleep_on_rate_limit = defs.LLM_SLEEP_ON_RATE_LIMIT_RECOMMENDATION
|
||||
|
||||
llm_model = LLMParameters(
|
||||
api_key=api_key,
|
||||
api_base=api_base,
|
||||
api_version=api_version,
|
||||
organization=api_organization,
|
||||
proxy=api_proxy,
|
||||
type=llm_type,
|
||||
model=reader.str(Fragment.model) or defs.LLM_MODEL,
|
||||
encoding_model=encoding_model,
|
||||
max_tokens=reader.int(Fragment.max_tokens) or defs.LLM_MAX_TOKENS,
|
||||
temperature=reader.float(Fragment.temperature)
|
||||
or defs.LLM_TEMPERATURE,
|
||||
top_p=reader.float(Fragment.top_p) or defs.LLM_TOP_P,
|
||||
n=reader.int(Fragment.n) or defs.LLM_N,
|
||||
model_supports_json=reader.bool(Fragment.model_supports_json),
|
||||
request_timeout=reader.float(Fragment.request_timeout)
|
||||
or defs.LLM_REQUEST_TIMEOUT,
|
||||
audience=audience,
|
||||
deployment_name=deployment_name,
|
||||
tokens_per_minute=reader.int(Fragment.tpm)
|
||||
or defs.LLM_TOKENS_PER_MINUTE,
|
||||
requests_per_minute=reader.int(Fragment.rpm)
|
||||
or defs.LLM_REQUESTS_PER_MINUTE,
|
||||
max_retries=reader.int(Fragment.max_retries)
|
||||
or defs.LLM_MAX_RETRIES,
|
||||
max_retry_wait=reader.float(Fragment.max_retry_wait)
|
||||
or defs.LLM_MAX_RETRY_WAIT,
|
||||
sleep_on_rate_limit_recommendation=sleep_on_rate_limit,
|
||||
concurrent_requests=reader.int(Fragment.concurrent_requests)
|
||||
or defs.LLM_CONCURRENT_REQUESTS,
|
||||
)
|
||||
with reader.use(values.get("parallelization")):
|
||||
llm_parallelization_model = ParallelizationParameters(
|
||||
stagger=reader.float("stagger", Fragment.thread_stagger)
|
||||
or defs.PARALLELIZATION_STAGGER,
|
||||
num_threads=reader.int("num_threads", Fragment.thread_count)
|
||||
or defs.PARALLELIZATION_NUM_THREADS,
|
||||
)
|
||||
embeddings_config = values.get("embeddings") or {}
|
||||
with reader.envvar_prefix(Section.embedding), reader.use(embeddings_config):
|
||||
embeddings_target = reader.str("target")
|
||||
# TODO: remove the type ignore annotations below once the new config engine has been refactored
|
||||
embeddings_model = TextEmbeddingConfig(
|
||||
llm=hydrate_embeddings_params(embeddings_config, llm_model), # type: ignore
|
||||
parallelization=hydrate_parallelization_params(
|
||||
embeddings_config, # type: ignore
|
||||
llm_parallelization_model, # type: ignore
|
||||
),
|
||||
vector_store=embeddings_config.get("vector_store", None),
|
||||
async_mode=hydrate_async_type(embeddings_config, async_mode), # type: ignore
|
||||
target=(
|
||||
TextEmbeddingTarget(embeddings_target)
|
||||
if embeddings_target
|
||||
else defs.EMBEDDING_TARGET
|
||||
),
|
||||
batch_size=reader.int("batch_size") or defs.EMBEDDING_BATCH_SIZE,
|
||||
batch_max_tokens=reader.int("batch_max_tokens")
|
||||
or defs.EMBEDDING_BATCH_MAX_TOKENS,
|
||||
skip=reader.list("skip") or [],
|
||||
)
|
||||
with (
|
||||
reader.envvar_prefix(Section.node2vec),
|
||||
reader.use(values.get("embed_graph")),
|
||||
):
|
||||
use_lcc = reader.bool("use_lcc")
|
||||
embed_graph_model = EmbedGraphConfig(
|
||||
enabled=reader.bool(Fragment.enabled) or defs.NODE2VEC_ENABLED,
|
||||
dimensions=reader.int("dimensions") or defs.NODE2VEC_DIMENSIONS,
|
||||
num_walks=reader.int("num_walks") or defs.NODE2VEC_NUM_WALKS,
|
||||
walk_length=reader.int("walk_length") or defs.NODE2VEC_WALK_LENGTH,
|
||||
window_size=reader.int("window_size") or defs.NODE2VEC_WINDOW_SIZE,
|
||||
iterations=reader.int("iterations") or defs.NODE2VEC_ITERATIONS,
|
||||
random_seed=reader.int("random_seed") or defs.NODE2VEC_RANDOM_SEED,
|
||||
use_lcc=use_lcc if use_lcc is not None else defs.USE_LCC,
|
||||
)
|
||||
with reader.envvar_prefix(Section.input), reader.use(values.get("input")):
|
||||
input_type = reader.str("type")
|
||||
file_type = reader.str(Fragment.file_type)
|
||||
input_model = InputConfig(
|
||||
file_type=(
|
||||
InputFileType(file_type) if file_type else defs.INPUT_FILE_TYPE
|
||||
),
|
||||
type=(InputType(input_type) if input_type else defs.INPUT_TYPE),
|
||||
encoding=reader.str("file_encoding", Fragment.encoding)
|
||||
or defs.INPUT_FILE_ENCODING,
|
||||
base_dir=reader.str(Fragment.base_dir) or defs.INPUT_BASE_DIR,
|
||||
file_pattern=reader.str("file_pattern")
|
||||
or (
|
||||
defs.INPUT_TEXT_PATTERN
|
||||
if file_type == InputFileType.text
|
||||
else defs.INPUT_CSV_PATTERN
|
||||
),
|
||||
source_column=reader.str("source_column"),
|
||||
timestamp_column=reader.str("timestamp_column"),
|
||||
timestamp_format=reader.str("timestamp_format"),
|
||||
text_column=reader.str("text_column") or defs.INPUT_TEXT_COLUMN,
|
||||
title_column=reader.str("title_column"),
|
||||
document_attribute_columns=reader.list("document_attribute_columns")
|
||||
or [],
|
||||
connection_string=reader.str(Fragment.conn_string),
|
||||
storage_account_blob_url=reader.str(Fragment.storage_account_blob_url),
|
||||
container_name=reader.str(Fragment.container_name),
|
||||
)
|
||||
with reader.envvar_prefix(Section.cache), reader.use(values.get("cache")):
|
||||
c_type = reader.str(Fragment.type)
|
||||
cache_model = CacheConfig(
|
||||
type=CacheType(c_type) if c_type else defs.CACHE_TYPE,
|
||||
connection_string=reader.str(Fragment.conn_string),
|
||||
storage_account_blob_url=reader.str(Fragment.storage_account_blob_url),
|
||||
container_name=reader.str(Fragment.container_name),
|
||||
base_dir=reader.str(Fragment.base_dir) or defs.CACHE_BASE_DIR,
|
||||
cosmosdb_account_url=reader.str(Fragment.cosmosdb_account_url),
|
||||
)
|
||||
with (
|
||||
reader.envvar_prefix(Section.reporting),
|
||||
reader.use(values.get("reporting")),
|
||||
):
|
||||
r_type = reader.str(Fragment.type)
|
||||
reporting_model = ReportingConfig(
|
||||
type=ReportingType(r_type) if r_type else defs.REPORTING_TYPE,
|
||||
connection_string=reader.str(Fragment.conn_string),
|
||||
storage_account_blob_url=reader.str(Fragment.storage_account_blob_url),
|
||||
container_name=reader.str(Fragment.container_name),
|
||||
base_dir=reader.str(Fragment.base_dir) or defs.REPORTING_BASE_DIR,
|
||||
)
|
||||
with reader.envvar_prefix(Section.storage), reader.use(values.get("storage")):
|
||||
s_type = reader.str(Fragment.type)
|
||||
storage_model = StorageConfig(
|
||||
type=StorageType(s_type) if s_type else defs.STORAGE_TYPE,
|
||||
connection_string=reader.str(Fragment.conn_string),
|
||||
storage_account_blob_url=reader.str(Fragment.storage_account_blob_url),
|
||||
container_name=reader.str(Fragment.container_name),
|
||||
base_dir=reader.str(Fragment.base_dir) or defs.STORAGE_BASE_DIR,
|
||||
cosmosdb_account_url=reader.str(Fragment.cosmosdb_account_url),
|
||||
)
|
||||
|
||||
with (
|
||||
reader.envvar_prefix(Section.update_index_storage),
|
||||
reader.use(values.get("update_index_storage")),
|
||||
):
|
||||
s_type = reader.str(Fragment.type)
|
||||
if s_type:
|
||||
update_index_storage_model = StorageConfig(
|
||||
type=StorageType(s_type) if s_type else defs.STORAGE_TYPE,
|
||||
connection_string=reader.str(Fragment.conn_string),
|
||||
storage_account_blob_url=reader.str(
|
||||
Fragment.storage_account_blob_url
|
||||
),
|
||||
container_name=reader.str(Fragment.container_name),
|
||||
base_dir=reader.str(Fragment.base_dir)
|
||||
or defs.UPDATE_STORAGE_BASE_DIR,
|
||||
)
|
||||
else:
|
||||
update_index_storage_model = None
|
||||
with reader.envvar_prefix(Section.chunk), reader.use(values.get("chunks")):
|
||||
group_by_columns = reader.list("group_by_columns", "BY_COLUMNS")
|
||||
if group_by_columns is None:
|
||||
group_by_columns = defs.CHUNK_GROUP_BY_COLUMNS
|
||||
encoding_model = (
|
||||
reader.str(Fragment.encoding_model) or global_encoding_model
|
||||
)
|
||||
strategy = reader.str("strategy")
|
||||
chunks_model = ChunkingConfig(
|
||||
size=reader.int("size") or defs.CHUNK_SIZE,
|
||||
overlap=reader.int("overlap") or defs.CHUNK_OVERLAP,
|
||||
group_by_columns=group_by_columns,
|
||||
encoding_model=encoding_model,
|
||||
strategy=ChunkStrategyType(strategy)
|
||||
if strategy
|
||||
else ChunkStrategyType.tokens,
|
||||
)
|
||||
with (
|
||||
reader.envvar_prefix(Section.snapshot),
|
||||
reader.use(values.get("snapshots")),
|
||||
):
|
||||
snapshots_model = SnapshotsConfig(
|
||||
graphml=reader.bool("graphml") or defs.SNAPSHOTS_GRAPHML,
|
||||
embeddings=reader.bool("embeddings") or defs.SNAPSHOTS_EMBEDDINGS,
|
||||
transient=reader.bool("transient") or defs.SNAPSHOTS_TRANSIENT,
|
||||
)
|
||||
with reader.envvar_prefix(Section.umap), reader.use(values.get("umap")):
|
||||
umap_model = UmapConfig(
|
||||
enabled=reader.bool(Fragment.enabled) or defs.UMAP_ENABLED,
|
||||
)
|
||||
|
||||
entity_extraction_config = values.get("entity_extraction") or {}
|
||||
with (
|
||||
reader.envvar_prefix(Section.entity_extraction),
|
||||
reader.use(entity_extraction_config),
|
||||
):
|
||||
max_gleanings = reader.int(Fragment.max_gleanings)
|
||||
max_gleanings = (
|
||||
max_gleanings
|
||||
if max_gleanings is not None
|
||||
else defs.ENTITY_EXTRACTION_MAX_GLEANINGS
|
||||
)
|
||||
encoding_model = (
|
||||
reader.str(Fragment.encoding_model) or global_encoding_model
|
||||
)
|
||||
|
||||
entity_extraction_model = EntityExtractionConfig(
|
||||
llm=hydrate_llm_params(entity_extraction_config, llm_model),
|
||||
parallelization=hydrate_parallelization_params(
|
||||
entity_extraction_config, llm_parallelization_model
|
||||
),
|
||||
async_mode=hydrate_async_type(entity_extraction_config, async_mode),
|
||||
entity_types=reader.list("entity_types")
|
||||
or defs.ENTITY_EXTRACTION_ENTITY_TYPES,
|
||||
max_gleanings=max_gleanings,
|
||||
prompt=reader.str("prompt", Fragment.prompt_file),
|
||||
strategy=entity_extraction_config.get("strategy"),
|
||||
encoding_model=encoding_model,
|
||||
)
|
||||
|
||||
claim_extraction_config = values.get("claim_extraction") or {}
|
||||
with (
|
||||
reader.envvar_prefix(Section.claim_extraction),
|
||||
reader.use(claim_extraction_config),
|
||||
):
|
||||
max_gleanings = reader.int(Fragment.max_gleanings)
|
||||
max_gleanings = (
|
||||
max_gleanings if max_gleanings is not None else defs.CLAIM_MAX_GLEANINGS
|
||||
)
|
||||
encoding_model = (
|
||||
reader.str(Fragment.encoding_model) or global_encoding_model
|
||||
)
|
||||
claim_extraction_model = ClaimExtractionConfig(
|
||||
enabled=reader.bool(Fragment.enabled) or defs.CLAIM_EXTRACTION_ENABLED,
|
||||
llm=hydrate_llm_params(claim_extraction_config, llm_model),
|
||||
parallelization=hydrate_parallelization_params(
|
||||
claim_extraction_config, llm_parallelization_model
|
||||
),
|
||||
async_mode=hydrate_async_type(claim_extraction_config, async_mode),
|
||||
description=reader.str("description") or defs.CLAIM_DESCRIPTION,
|
||||
prompt=reader.str("prompt", Fragment.prompt_file),
|
||||
max_gleanings=max_gleanings,
|
||||
encoding_model=encoding_model,
|
||||
)
|
||||
|
||||
community_report_config = values.get("community_reports") or {}
|
||||
with (
|
||||
reader.envvar_prefix(Section.community_reports),
|
||||
reader.use(community_report_config),
|
||||
):
|
||||
community_reports_model = CommunityReportsConfig(
|
||||
llm=hydrate_llm_params(community_report_config, llm_model),
|
||||
parallelization=hydrate_parallelization_params(
|
||||
community_report_config, llm_parallelization_model
|
||||
),
|
||||
async_mode=hydrate_async_type(community_report_config, async_mode),
|
||||
prompt=reader.str("prompt", Fragment.prompt_file),
|
||||
max_length=reader.int(Fragment.max_length)
|
||||
or defs.COMMUNITY_REPORT_MAX_LENGTH,
|
||||
max_input_length=reader.int("max_input_length")
|
||||
or defs.COMMUNITY_REPORT_MAX_INPUT_LENGTH,
|
||||
)
|
||||
|
||||
summarize_description_config = values.get("summarize_descriptions") or {}
|
||||
with (
|
||||
reader.envvar_prefix(Section.summarize_descriptions),
|
||||
reader.use(values.get("summarize_descriptions")),
|
||||
):
|
||||
summarize_descriptions_model = SummarizeDescriptionsConfig(
|
||||
llm=hydrate_llm_params(summarize_description_config, llm_model),
|
||||
parallelization=hydrate_parallelization_params(
|
||||
summarize_description_config, llm_parallelization_model
|
||||
),
|
||||
async_mode=hydrate_async_type(summarize_description_config, async_mode),
|
||||
prompt=reader.str("prompt", Fragment.prompt_file),
|
||||
max_length=reader.int(Fragment.max_length)
|
||||
or defs.SUMMARIZE_DESCRIPTIONS_MAX_LENGTH,
|
||||
)
|
||||
|
||||
with reader.use(values.get("cluster_graph")):
|
||||
use_lcc = reader.bool("use_lcc")
|
||||
seed = reader.int("seed")
|
||||
cluster_graph_model = ClusterGraphConfig(
|
||||
max_cluster_size=reader.int("max_cluster_size")
|
||||
or defs.MAX_CLUSTER_SIZE,
|
||||
use_lcc=use_lcc if use_lcc is not None else defs.USE_LCC,
|
||||
seed=seed if seed is not None else defs.CLUSTER_GRAPH_SEED,
|
||||
)
|
||||
|
||||
with (
|
||||
reader.use(values.get("local_search")),
|
||||
reader.envvar_prefix(Section.local_search),
|
||||
):
|
||||
local_search_model = LocalSearchConfig(
|
||||
prompt=reader.str("prompt") or None,
|
||||
text_unit_prop=reader.float("text_unit_prop")
|
||||
or defs.LOCAL_SEARCH_TEXT_UNIT_PROP,
|
||||
community_prop=reader.float("community_prop")
|
||||
or defs.LOCAL_SEARCH_COMMUNITY_PROP,
|
||||
conversation_history_max_turns=reader.int(
|
||||
"conversation_history_max_turns"
|
||||
)
|
||||
or defs.LOCAL_SEARCH_CONVERSATION_HISTORY_MAX_TURNS,
|
||||
top_k_entities=reader.int("top_k_entities")
|
||||
or defs.LOCAL_SEARCH_TOP_K_MAPPED_ENTITIES,
|
||||
top_k_relationships=reader.int("top_k_relationships")
|
||||
or defs.LOCAL_SEARCH_TOP_K_RELATIONSHIPS,
|
||||
temperature=reader.float("llm_temperature")
|
||||
or defs.LOCAL_SEARCH_LLM_TEMPERATURE,
|
||||
top_p=reader.float("llm_top_p") or defs.LOCAL_SEARCH_LLM_TOP_P,
|
||||
n=reader.int("llm_n") or defs.LOCAL_SEARCH_LLM_N,
|
||||
max_tokens=reader.int(Fragment.max_tokens)
|
||||
or defs.LOCAL_SEARCH_MAX_TOKENS,
|
||||
llm_max_tokens=reader.int("llm_max_tokens")
|
||||
or defs.LOCAL_SEARCH_LLM_MAX_TOKENS,
|
||||
)
|
||||
|
||||
with (
|
||||
reader.use(values.get("global_search")),
|
||||
reader.envvar_prefix(Section.global_search),
|
||||
):
|
||||
global_search_model = GlobalSearchConfig(
|
||||
map_prompt=reader.str("map_prompt") or None,
|
||||
reduce_prompt=reader.str("reduce_prompt") or None,
|
||||
knowledge_prompt=reader.str("knowledge_prompt") or None,
|
||||
temperature=reader.float("llm_temperature")
|
||||
or defs.GLOBAL_SEARCH_LLM_TEMPERATURE,
|
||||
top_p=reader.float("llm_top_p") or defs.GLOBAL_SEARCH_LLM_TOP_P,
|
||||
n=reader.int("llm_n") or defs.GLOBAL_SEARCH_LLM_N,
|
||||
max_tokens=reader.int(Fragment.max_tokens)
|
||||
or defs.GLOBAL_SEARCH_MAX_TOKENS,
|
||||
data_max_tokens=reader.int("data_max_tokens")
|
||||
or defs.GLOBAL_SEARCH_DATA_MAX_TOKENS,
|
||||
map_max_tokens=reader.int("map_max_tokens")
|
||||
or defs.GLOBAL_SEARCH_MAP_MAX_TOKENS,
|
||||
reduce_max_tokens=reader.int("reduce_max_tokens")
|
||||
or defs.GLOBAL_SEARCH_REDUCE_MAX_TOKENS,
|
||||
concurrency=reader.int("concurrency") or defs.GLOBAL_SEARCH_CONCURRENCY,
|
||||
)
|
||||
|
||||
with (
|
||||
reader.use(values.get("drift_search")),
|
||||
reader.envvar_prefix(Section.drift_search),
|
||||
):
|
||||
drift_search_model = DRIFTSearchConfig(
|
||||
prompt=reader.str("prompt") or None,
|
||||
reduce_prompt=reader.str("reduce_prompt") or None,
|
||||
temperature=reader.float("llm_temperature")
|
||||
or defs.DRIFT_SEARCH_LLM_TEMPERATURE,
|
||||
top_p=reader.float("llm_top_p") or defs.DRIFT_SEARCH_LLM_TOP_P,
|
||||
n=reader.int("llm_n") or defs.DRIFT_SEARCH_LLM_N,
|
||||
max_tokens=reader.int(Fragment.max_tokens)
|
||||
or defs.DRIFT_SEARCH_MAX_TOKENS,
|
||||
data_max_tokens=reader.int("data_max_tokens")
|
||||
or defs.DRIFT_SEARCH_DATA_MAX_TOKENS,
|
||||
reduce_max_tokens=reader.int("reduce_max_tokens")
|
||||
or defs.DRIFT_SEARCH_REDUCE_MAX_TOKENS,
|
||||
reduce_temperature=reader.float("reduce_temperature")
|
||||
or defs.DRIFT_SEARCH_REDUCE_LLM_TEMPERATURE,
|
||||
concurrency=reader.int("concurrency") or defs.DRIFT_SEARCH_CONCURRENCY,
|
||||
drift_k_followups=reader.int("drift_k_followups")
|
||||
or defs.DRIFT_SEARCH_K_FOLLOW_UPS,
|
||||
primer_folds=reader.int("primer_folds")
|
||||
or defs.DRIFT_SEARCH_PRIMER_FOLDS,
|
||||
primer_llm_max_tokens=reader.int("primer_llm_max_tokens")
|
||||
or defs.DRIFT_SEARCH_PRIMER_MAX_TOKENS,
|
||||
n_depth=reader.int("n_depth") or defs.DRIFT_N_DEPTH,
|
||||
local_search_text_unit_prop=reader.float("local_search_text_unit_prop")
|
||||
or defs.DRIFT_LOCAL_SEARCH_TEXT_UNIT_PROP,
|
||||
local_search_community_prop=reader.float("local_search_community_prop")
|
||||
or defs.DRIFT_LOCAL_SEARCH_COMMUNITY_PROP,
|
||||
local_search_top_k_mapped_entities=reader.int(
|
||||
"local_search_top_k_mapped_entities"
|
||||
)
|
||||
or defs.DRIFT_LOCAL_SEARCH_TOP_K_MAPPED_ENTITIES,
|
||||
local_search_top_k_relationships=reader.int(
|
||||
"local_search_top_k_relationships"
|
||||
)
|
||||
or defs.DRIFT_LOCAL_SEARCH_TOP_K_RELATIONSHIPS,
|
||||
local_search_max_data_tokens=reader.int("local_search_max_data_tokens")
|
||||
or defs.DRIFT_LOCAL_SEARCH_MAX_TOKENS,
|
||||
local_search_temperature=reader.float("local_search_temperature")
|
||||
or defs.DRIFT_LOCAL_SEARCH_LLM_TEMPERATURE,
|
||||
local_search_top_p=reader.float("local_search_top_p")
|
||||
or defs.DRIFT_LOCAL_SEARCH_LLM_TOP_P,
|
||||
local_search_n=reader.int("local_search_n")
|
||||
or defs.DRIFT_LOCAL_SEARCH_LLM_N,
|
||||
local_search_llm_max_gen_tokens=reader.int(
|
||||
"local_search_llm_max_gen_tokens"
|
||||
)
|
||||
or defs.DRIFT_LOCAL_SEARCH_LLM_MAX_TOKENS,
|
||||
)
|
||||
|
||||
with (
|
||||
reader.use(values.get("basic_search")),
|
||||
reader.envvar_prefix(Section.basic_search),
|
||||
):
|
||||
basic_search_model = BasicSearchConfig(
|
||||
prompt=reader.str("prompt") or None,
|
||||
text_unit_prop=reader.float("text_unit_prop")
|
||||
or defs.BASIC_SEARCH_TEXT_UNIT_PROP,
|
||||
conversation_history_max_turns=reader.int(
|
||||
"conversation_history_max_turns"
|
||||
)
|
||||
or defs.BASIC_SEARCH_CONVERSATION_HISTORY_MAX_TURNS,
|
||||
temperature=reader.float("llm_temperature")
|
||||
or defs.BASIC_SEARCH_LLM_TEMPERATURE,
|
||||
top_p=reader.float("llm_top_p") or defs.BASIC_SEARCH_LLM_TOP_P,
|
||||
n=reader.int("llm_n") or defs.BASIC_SEARCH_LLM_N,
|
||||
max_tokens=reader.int(Fragment.max_tokens)
|
||||
or defs.BASIC_SEARCH_MAX_TOKENS,
|
||||
llm_max_tokens=reader.int("llm_max_tokens")
|
||||
or defs.BASIC_SEARCH_LLM_MAX_TOKENS,
|
||||
)
|
||||
|
||||
skip_workflows = reader.list("skip_workflows") or []
|
||||
|
||||
return GraphRagConfig(
|
||||
root_dir=root_dir,
|
||||
llm=llm_model,
|
||||
parallelization=llm_parallelization_model,
|
||||
async_mode=async_mode,
|
||||
embeddings=embeddings_model,
|
||||
embed_graph=embed_graph_model,
|
||||
reporting=reporting_model,
|
||||
storage=storage_model,
|
||||
update_index_storage=update_index_storage_model,
|
||||
cache=cache_model,
|
||||
input=input_model,
|
||||
chunks=chunks_model,
|
||||
snapshots=snapshots_model,
|
||||
entity_extraction=entity_extraction_model,
|
||||
claim_extraction=claim_extraction_model,
|
||||
community_reports=community_reports_model,
|
||||
summarize_descriptions=summarize_descriptions_model,
|
||||
umap=umap_model,
|
||||
cluster_graph=cluster_graph_model,
|
||||
encoding_model=global_encoding_model,
|
||||
skip_workflows=skip_workflows,
|
||||
local_search=local_search_model,
|
||||
global_search=global_search_model,
|
||||
drift_search=drift_search_model,
|
||||
basic_search=basic_search_model,
|
||||
)
|
||||
|
||||
|
||||
class Fragment(str, Enum):
|
||||
"""Configuration Fragments."""
|
||||
|
||||
api_base = "API_BASE"
|
||||
api_key = "API_KEY"
|
||||
api_version = "API_VERSION"
|
||||
api_organization = "API_ORGANIZATION"
|
||||
api_proxy = "API_PROXY"
|
||||
async_mode = "ASYNC_MODE"
|
||||
audience = "AUDIENCE"
|
||||
base_dir = "BASE_DIR"
|
||||
concurrent_requests = "CONCURRENT_REQUESTS"
|
||||
conn_string = "CONNECTION_STRING"
|
||||
container_name = "CONTAINER_NAME"
|
||||
cosmosdb_account_url = "COSMOSDB_ACCOUNT_URL"
|
||||
deployment_name = "DEPLOYMENT_NAME"
|
||||
description = "DESCRIPTION"
|
||||
enabled = "ENABLED"
|
||||
encoding = "ENCODING"
|
||||
encoding_model = "ENCODING_MODEL"
|
||||
file_type = "FILE_TYPE"
|
||||
max_gleanings = "MAX_GLEANINGS"
|
||||
max_length = "MAX_LENGTH"
|
||||
max_retries = "MAX_RETRIES"
|
||||
max_retry_wait = "MAX_RETRY_WAIT"
|
||||
max_tokens = "MAX_TOKENS"
|
||||
temperature = "TEMPERATURE"
|
||||
top_p = "TOP_P"
|
||||
n = "N"
|
||||
model = "MODEL"
|
||||
model_supports_json = "MODEL_SUPPORTS_JSON"
|
||||
prompt_file = "PROMPT_FILE"
|
||||
request_timeout = "REQUEST_TIMEOUT"
|
||||
rpm = "REQUESTS_PER_MINUTE"
|
||||
sleep_recommendation = "SLEEP_ON_RATE_LIMIT_RECOMMENDATION"
|
||||
storage_account_blob_url = "STORAGE_ACCOUNT_BLOB_URL"
|
||||
thread_count = "THREAD_COUNT"
|
||||
thread_stagger = "THREAD_STAGGER"
|
||||
tpm = "TOKENS_PER_MINUTE"
|
||||
type = "TYPE"
|
||||
|
||||
|
||||
class Section(str, Enum):
|
||||
"""Configuration Sections."""
|
||||
|
||||
base = "BASE"
|
||||
cache = "CACHE"
|
||||
chunk = "CHUNK"
|
||||
claim_extraction = "CLAIM_EXTRACTION"
|
||||
community_reports = "COMMUNITY_REPORTS"
|
||||
embedding = "EMBEDDING"
|
||||
entity_extraction = "ENTITY_EXTRACTION"
|
||||
graphrag = "GRAPHRAG"
|
||||
input = "INPUT"
|
||||
llm = "LLM"
|
||||
node2vec = "NODE2VEC"
|
||||
reporting = "REPORTING"
|
||||
snapshot = "SNAPSHOT"
|
||||
storage = "STORAGE"
|
||||
summarize_descriptions = "SUMMARIZE_DESCRIPTIONS"
|
||||
umap = "UMAP"
|
||||
update_index_storage = "UPDATE_INDEX_STORAGE"
|
||||
local_search = "LOCAL_SEARCH"
|
||||
global_search = "GLOBAL_SEARCH"
|
||||
drift_search = "DRIFT_SEARCH"
|
||||
basic_search = "BASIC_SEARCH"
|
||||
|
||||
|
||||
def _is_azure(llm_type: LLMType | None) -> bool:
|
||||
return (
|
||||
llm_type == LLMType.AzureOpenAIChat or llm_type == LLMType.AzureOpenAIEmbedding
|
||||
)
|
||||
|
||||
|
||||
def _make_env(root_dir: str) -> Env:
|
||||
read_dotenv(root_dir)
|
||||
env = Env(expand_vars=True)
|
||||
env.read_env()
|
||||
return env
|
||||
|
||||
|
||||
def _token_replace(data: dict):
|
||||
"""Replace env-var tokens in a dictionary object."""
|
||||
for key, value in data.items():
|
||||
if isinstance(value, dict):
|
||||
_token_replace(value)
|
||||
elif isinstance(value, str):
|
||||
data[key] = os.path.expandvars(value)
|
||||
if root_dir:
|
||||
root_path = Path(root_dir).resolve()
|
||||
values["root_dir"] = str(root_path)
|
||||
return GraphRagConfig(**values)
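For orientation, a minimal sketch (not part of this change; the IDs and model names come from the defaults module shown below) of how the refactored create_graphrag_config is expected to be called once hydration and env-var overlays are gone: the caller passes a plain dict, ${VAR} tokens are assumed to be expanded by _token_replace, and pydantic validation does the rest.

import os

import graphrag.config.defaults as defs
from graphrag.config.create_graphrag_config import create_graphrag_config

os.environ.setdefault("GRAPHRAG_API_KEY", "<api-key>")  # normally provided via .env

values = {
    "models": {
        defs.DEFAULT_CHAT_MODEL_ID: {
            "type": defs.LLM_TYPE.value,
            "model": defs.LLM_MODEL,
            "api_key": "${GRAPHRAG_API_KEY}",  # token expanded against the environment
        },
        defs.DEFAULT_EMBEDDING_MODEL_ID: {
            "type": defs.EMBEDDING_TYPE.value,
            "model": defs.EMBEDDING_MODEL,
            "api_key": "${GRAPHRAG_API_KEY}",
        },
    },
}
config = create_graphrag_config(values, root_dir=".")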
|
||||
|
||||
@ -8,15 +8,18 @@ from pathlib import Path
|
||||
from graphrag.config.enums import (
|
||||
AsyncType,
|
||||
CacheType,
|
||||
ChunkStrategyType,
|
||||
InputFileType,
|
||||
InputType,
|
||||
LLMType,
|
||||
OutputType,
|
||||
ReportingType,
|
||||
StorageType,
|
||||
TextEmbeddingTarget,
|
||||
)
|
||||
from graphrag.vector_stores.factory import VectorStoreType
|
||||
|
||||
DEFAULT_CHAT_MODEL_ID = "default_chat_model"
|
||||
DEFAULT_EMBEDDING_MODEL_ID = "default_embedding_model"
|
||||
ASYNC_MODE = AsyncType.Threaded
|
||||
ENCODING_MODEL = "cl100k_base"
|
||||
AZURE_AUDIENCE = "https://cognitiveservices.azure.com/.default"
|
||||
@ -47,24 +50,29 @@ EMBEDDING_MODEL = "text-embedding-3-small"
|
||||
EMBEDDING_BATCH_SIZE = 16
|
||||
EMBEDDING_BATCH_MAX_TOKENS = 8191
|
||||
EMBEDDING_TARGET = TextEmbeddingTarget.required
|
||||
EMBEDDING_MODEL_ID = DEFAULT_EMBEDDING_MODEL_ID
|
||||
|
||||
CACHE_TYPE = CacheType.file
|
||||
CACHE_BASE_DIR = "cache"
|
||||
CHUNK_SIZE = 1200
|
||||
CHUNK_OVERLAP = 100
|
||||
CHUNK_GROUP_BY_COLUMNS = ["id"]
|
||||
CHUNK_STRATEGY = ChunkStrategyType.tokens
|
||||
CLAIM_DESCRIPTION = (
|
||||
"Any claims or facts that could be relevant to information discovery."
|
||||
)
|
||||
CLAIM_MAX_GLEANINGS = 1
|
||||
CLAIM_EXTRACTION_ENABLED = False
|
||||
CLAIM_EXTRACTION_MODEL_ID = DEFAULT_CHAT_MODEL_ID
|
||||
MAX_CLUSTER_SIZE = 10
|
||||
USE_LCC = True
|
||||
CLUSTER_GRAPH_SEED = 0xDEADBEEF
|
||||
COMMUNITY_REPORT_MAX_LENGTH = 2000
|
||||
COMMUNITY_REPORT_MAX_INPUT_LENGTH = 8000
|
||||
COMMUNITY_REPORT_MODEL_ID = DEFAULT_CHAT_MODEL_ID
|
||||
ENTITY_EXTRACTION_ENTITY_TYPES = ["organization", "person", "geo", "event"]
|
||||
ENTITY_EXTRACTION_MAX_GLEANINGS = 1
|
||||
ENTITY_EXTRACTION_MODEL_ID = DEFAULT_CHAT_MODEL_ID
|
||||
INPUT_FILE_TYPE = InputFileType.text
|
||||
INPUT_TYPE = InputType.file
|
||||
INPUT_BASE_DIR = "input"
|
||||
@ -86,25 +94,18 @@ REPORTING_BASE_DIR = "logs"
|
||||
SNAPSHOTS_GRAPHML = False
|
||||
SNAPSHOTS_EMBEDDINGS = False
|
||||
SNAPSHOTS_TRANSIENT = False
|
||||
STORAGE_BASE_DIR = "output"
|
||||
STORAGE_TYPE = StorageType.file
|
||||
OUTPUT_BASE_DIR = "output"
|
||||
OUTPUT_TYPE = OutputType.file
|
||||
SUMMARIZE_DESCRIPTIONS_MAX_LENGTH = 500
|
||||
SUMMARIZE_MODEL_ID = DEFAULT_CHAT_MODEL_ID
|
||||
UMAP_ENABLED = False
|
||||
UPDATE_STORAGE_BASE_DIR = "update_output"
|
||||
UPDATE_OUTPUT_BASE_DIR = "update_output"
|
||||
|
||||
VECTOR_STORE = f"""
|
||||
type: {VectorStoreType.LanceDB.value} # one of [lancedb, azure_ai_search, cosmosdb]
|
||||
db_uri: '{(Path(STORAGE_BASE_DIR) / "lancedb")!s}'
|
||||
collection_name: default
|
||||
overwrite: true\
|
||||
"""
|
||||
|
||||
VECTOR_STORE_DICT = {
|
||||
"type": VectorStoreType.LanceDB.value,
|
||||
"db_uri": str(Path(STORAGE_BASE_DIR) / "lancedb"),
|
||||
"collection_name": "default",
|
||||
"overwrite": True,
|
||||
}
|
||||
VECTOR_STORE_TYPE = VectorStoreType.LanceDB.value
|
||||
VECTOR_STORE_DB_URI = str(Path(OUTPUT_BASE_DIR) / "lancedb")
|
||||
VECTOR_STORE_CONTAINER_NAME = "default"
|
||||
VECTOR_STORE_OVERWRITE = True
|
||||
|
||||
# Local Search
|
||||
LOCAL_SEARCH_TEXT_UNIT_PROP = 0.5
|
||||
|
||||
@ -5,7 +5,6 @@
|
||||
|
||||
from graphrag.config.enums import TextEmbeddingTarget
|
||||
from graphrag.config.models.graph_rag_config import GraphRagConfig
|
||||
from graphrag.config.models.text_embedding_config import TextEmbeddingConfig
|
||||
|
||||
entity_title_embedding = "entity.title"
|
||||
entity_description_embedding = "entity.description"
|
||||
@ -34,12 +33,14 @@ required_embeddings: set[str] = {
|
||||
|
||||
|
||||
def get_embedded_fields(settings: GraphRagConfig) -> set[str]:
|
||||
"""Get the fields to embed based on the enum or specifically skipped embeddings."""
|
||||
"""Get the fields to embed based on the enum or specifically selected embeddings."""
|
||||
match settings.embeddings.target:
|
||||
case TextEmbeddingTarget.all:
|
||||
return all_embeddings.difference(settings.embeddings.skip)
|
||||
return all_embeddings
|
||||
case TextEmbeddingTarget.required:
|
||||
return required_embeddings
|
||||
case TextEmbeddingTarget.selected:
|
||||
return set(settings.embeddings.names)
|
||||
case TextEmbeddingTarget.none:
|
||||
return set()
|
||||
case _:
|
||||
@ -48,24 +49,53 @@ def get_embedded_fields(settings: GraphRagConfig) -> set[str]:
|
||||
|
||||
|
||||
def get_embedding_settings(
|
||||
settings: TextEmbeddingConfig,
|
||||
settings: GraphRagConfig,
|
||||
vector_store_params: dict | None = None,
|
||||
) -> dict:
|
||||
"""Transform GraphRAG config into settings for workflows."""
|
||||
# TEMP
|
||||
vector_store_settings = settings.vector_store
|
||||
embeddings_llm_settings = settings.get_language_model_config(
|
||||
settings.embeddings.model_id
|
||||
)
|
||||
vector_store_settings = settings.vector_store.model_dump()
|
||||
if vector_store_settings is None:
|
||||
return {"strategy": settings.resolved_strategy()}
|
||||
return {
|
||||
"strategy": settings.embeddings.resolved_strategy(embeddings_llm_settings)
|
||||
}
|
||||
#
|
||||
# If we get to this point, settings.vector_store is defined, and there's a specific setting for this embedding.
|
||||
# settings.vector_store.base contains connection information, or may be undefined
|
||||
# settings.vector_store.<vector_name> contains the specific settings for this embedding
|
||||
#
|
||||
strategy = settings.resolved_strategy() # get the default strategy
|
||||
strategy = settings.embeddings.resolved_strategy(
|
||||
embeddings_llm_settings
|
||||
) # get the default strategy
|
||||
strategy.update({
|
||||
"vector_store": {**(vector_store_params or {}), **vector_store_settings}
|
||||
"vector_store": {
|
||||
**(vector_store_params or {}),
|
||||
**(vector_store_settings),
|
||||
}
|
||||
}) # update the default strategy with the vector store settings
|
||||
# This ensures the vector store config is part of the strategy and not the global config
|
||||
return {
|
||||
"strategy": strategy,
|
||||
}
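As a rough usage sketch (the container name is just a placeholder), the dict returned here is what the embedding workflows consume; the vector store settings ride inside the strategy rather than being read from global config.

# Hypothetical call; `config` is a GraphRagConfig with the new top-level vector_store.
embedding_settings = get_embedding_settings(
    config, vector_store_params={"container_name": "default"}
)
strategy = embedding_settings["strategy"]
# strategy now holds the resolved model settings plus a merged "vector_store" entry.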
|
||||
|
||||
|
||||
def create_collection_name(
|
||||
container_name: str, embedding_name: str, validate: bool = True
|
||||
) -> str:
|
||||
"""
|
||||
Create a collection name for the embedding store.
|
||||
|
||||
Within any given vector store, we can have multiple sets of embeddings organized into projects.
|
||||
The `container` param is used for this partitioning, and is added as a prefix to the collection name for differentiation.
|
||||
|
||||
The embedding name is fixed, with the available list defined in graphrag.index.config.embeddings
|
||||
|
||||
Note that we use dot notation in our names, but many vector stores do not support this - so we convert to dashes.
|
||||
"""
|
||||
if validate and embedding_name not in all_embeddings:
|
||||
msg = f"Invalid embedding name: {embedding_name}"
|
||||
raise KeyError(msg)
|
||||
return f"{container_name}-{embedding_name}".replace(".", "-")
|
||||
@ -1,7 +1,7 @@
|
||||
# Copyright (c) 2024 Microsoft Corporation.
|
||||
# Licensed under the MIT License
|
||||
|
||||
"""A module containing 'PipelineCacheConfig', 'PipelineFileCacheConfig' and 'PipelineMemoryCacheConfig' models."""
|
||||
"""A module containing config enums."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
@ -53,17 +53,17 @@ class InputType(str, Enum):
|
||||
return f'"{self.value}"'
|
||||
|
||||
|
||||
class StorageType(str, Enum):
|
||||
"""The storage type for the pipeline."""
|
||||
class OutputType(str, Enum):
|
||||
"""The output type for the pipeline."""
|
||||
|
||||
file = "file"
|
||||
"""The file storage type."""
|
||||
"""The file output type."""
|
||||
memory = "memory"
|
||||
"""The memory storage type."""
|
||||
"""The memory output type."""
|
||||
blob = "blob"
|
||||
"""The blob storage type."""
|
||||
"""The blob output type."""
|
||||
cosmosdb = "cosmosdb"
|
||||
"""The cosmosdb storage type"""
|
||||
"""The cosmosdb output type"""
|
||||
|
||||
def __repr__(self):
|
||||
"""Get a string representation."""
|
||||
@ -90,6 +90,7 @@ class TextEmbeddingTarget(str, Enum):
|
||||
|
||||
all = "all"
|
||||
required = "required"
|
||||
selected = "selected"
|
||||
none = "none"
|
||||
|
||||
def __repr__(self):
|
||||
@ -116,8 +117,26 @@ class LLMType(str, Enum):
|
||||
return f'"{self.value}"'
|
||||
|
||||
|
||||
class AzureAuthType(str, Enum):
|
||||
"""AzureAuthType enum class definition."""
|
||||
|
||||
APIKey = "api_key"
|
||||
ManagedIdentity = "managed_identity"
|
||||
|
||||
|
||||
class AsyncType(str, Enum):
|
||||
"""Enum for the type of async to use."""
|
||||
|
||||
AsyncIO = "asyncio"
|
||||
Threaded = "threaded"
|
||||
|
||||
|
||||
class ChunkStrategyType(str, Enum):
|
||||
"""ChunkStrategy class definition."""
|
||||
|
||||
tokens = "tokens"
|
||||
sentence = "sentence"
|
||||
|
||||
def __repr__(self):
|
||||
"""Get a string representation."""
|
||||
return f'"{self.value}"'
|
||||
|
||||
@ -6,35 +6,54 @@
|
||||
class ApiKeyMissingError(ValueError):
|
||||
"""LLM Key missing error."""
|
||||
|
||||
def __init__(self, embedding: bool = False) -> None:
|
||||
def __init__(self, llm_type: str, azure_auth_type: str | None = None) -> None:
|
||||
"""Init method definition."""
|
||||
api_type = "Embedding" if embedding else "Completion"
|
||||
api_key = "GRAPHRAG_EMBEDDING_API_KEY" if embedding else "GRAPHRAG_LLM_API_KEY"
|
||||
msg = f"API Key is required for {api_type} API. Please set either the OPENAI_API_KEY, GRAPHRAG_API_KEY or {api_key} environment variable."
|
||||
msg = f"API Key is required for {llm_type}"
|
||||
if azure_auth_type:
|
||||
msg += f" when using {azure_auth_type} authentication"
|
||||
msg += ". Please rerun `graphrag init` and set the API_KEY."
|
||||
super().__init__(msg)
|
||||
|
||||
|
||||
class AzureApiBaseMissingError(ValueError):
|
||||
"""Azure API Base missing error."""
|
||||
|
||||
def __init__(self, embedding: bool = False) -> None:
|
||||
def __init__(self, llm_type: str) -> None:
|
||||
"""Init method definition."""
|
||||
api_type = "Embedding" if embedding else "Completion"
|
||||
api_base = "GRAPHRAG_EMBEDDING_API_BASE" if embedding else "GRAPHRAG_API_BASE"
|
||||
msg = f"API Base is required for {api_type} API. Please set either the OPENAI_API_BASE, GRAPHRAG_API_BASE or {api_base} environment variable."
|
||||
msg = f"API Base is required for {llm_type}. Please rerun `graphrag init` and set the api_base."
|
||||
super().__init__(msg)
|
||||
|
||||
|
||||
class AzureApiVersionMissingError(ValueError):
|
||||
"""Azure API version missing error."""
|
||||
|
||||
def __init__(self, llm_type: str) -> None:
|
||||
"""Init method definition."""
|
||||
msg = f"API Version is required for {llm_type}. Please rerun `graphrag init` and set the api_version."
|
||||
super().__init__(msg)
|
||||
|
||||
|
||||
class AzureDeploymentNameMissingError(ValueError):
|
||||
"""Azure Deployment Name missing error."""
|
||||
|
||||
def __init__(self, embedding: bool = False) -> None:
|
||||
def __init__(self, llm_type: str) -> None:
|
||||
"""Init method definition."""
|
||||
msg = f"Deployment name is required for {llm_type}. Please rerun `graphrag init` set the deployment_name."
|
||||
super().__init__(msg)
|
||||
|
||||
|
||||
class LanguageModelConfigMissingError(ValueError):
|
||||
"""Missing model configuration error."""
|
||||
|
||||
def __init__(self, key: str = "") -> None:
|
||||
"""Init method definition."""
|
||||
msg = f'A {key} model configuration is required. Please rerun `graphrag init` and set models["{key}"] in settings.yaml.'
|
||||
super().__init__(msg)
|
||||
|
||||
|
||||
class ConflictingSettingsError(ValueError):
|
||||
"""Missing model configuration error."""
|
||||
|
||||
def __init__(self, msg: str) -> None:
|
||||
"""Init method definition."""
|
||||
api_type = "Embedding" if embedding else "Completion"
|
||||
api_base = (
|
||||
"GRAPHRAG_EMBEDDING_DEPLOYMENT_NAME"
|
||||
if embedding
|
||||
else "GRAPHRAG_LLM_DEPLOYMENT_NAME"
|
||||
)
|
||||
msg = f"Deployment Name is required for {api_type} API. Please set either the OPENAI_DEPLOYMENT_NAME, GRAPHRAG_LLM_DEPLOYMENT_NAME or {api_base} environment variable."
|
||||
super().__init__(msg)
|
||||
|
||||
@ -12,38 +12,42 @@ INIT_YAML = f"""\
|
||||
### LLM settings ###
|
||||
## There are a number of settings to tune the threading and token limits for LLM calls - check the docs.
|
||||
|
||||
encoding_model: {defs.ENCODING_MODEL} # this needs to be matched to your model!
|
||||
|
||||
llm:
|
||||
api_key: ${{GRAPHRAG_API_KEY}} # set this in the generated .env file
|
||||
type: {defs.LLM_TYPE.value} # or azure_openai_chat
|
||||
model: {defs.LLM_MODEL}
|
||||
model_supports_json: true # recommended if this is available for your model.
|
||||
# audience: "https://cognitiveservices.azure.com/.default"
|
||||
# api_base: https://<instance>.openai.azure.com
|
||||
# api_version: 2024-02-15-preview
|
||||
# organization: <organization_id>
|
||||
# deployment_name: <azure_model_deployment_name>
|
||||
|
||||
parallelization:
|
||||
stagger: {defs.PARALLELIZATION_STAGGER}
|
||||
# num_threads: {defs.PARALLELIZATION_NUM_THREADS}
|
||||
|
||||
async_mode: {defs.ASYNC_MODE.value} # or asyncio
|
||||
|
||||
embeddings:
|
||||
async_mode: {defs.ASYNC_MODE.value} # or asyncio
|
||||
vector_store: {defs.VECTOR_STORE}
|
||||
llm:
|
||||
models:
|
||||
{defs.DEFAULT_CHAT_MODEL_ID}:
|
||||
api_key: ${{GRAPHRAG_API_KEY}} # set this in the generated .env file
|
||||
type: {defs.LLM_TYPE.value} # or azure_openai_chat
|
||||
model: {defs.LLM_MODEL}
|
||||
model_supports_json: true # recommended if this is available for your model.
|
||||
parallelization_num_threads: {defs.PARALLELIZATION_NUM_THREADS}
|
||||
parallelization_stagger: {defs.PARALLELIZATION_STAGGER}
|
||||
async_mode: {defs.ASYNC_MODE.value} # or asyncio
|
||||
# audience: "https://cognitiveservices.azure.com/.default"
|
||||
# api_base: https://<instance>.openai.azure.com
|
||||
# api_version: 2024-02-15-preview
|
||||
# organization: <organization_id>
|
||||
# deployment_name: <azure_model_deployment_name>
|
||||
{defs.DEFAULT_EMBEDDING_MODEL_ID}:
|
||||
api_key: ${{GRAPHRAG_API_KEY}}
|
||||
type: {defs.EMBEDDING_TYPE.value} # or azure_openai_embedding
|
||||
model: {defs.EMBEDDING_MODEL}
|
||||
parallelization_num_threads: {defs.PARALLELIZATION_NUM_THREADS}
|
||||
parallelization_stagger: {defs.PARALLELIZATION_STAGGER}
|
||||
async_mode: {defs.ASYNC_MODE.value} # or asyncio
|
||||
# api_base: https://<instance>.openai.azure.com
|
||||
# api_version: 2024-02-15-preview
|
||||
# audience: "https://cognitiveservices.azure.com/.default"
|
||||
# organization: <organization_id>
|
||||
# deployment_name: <azure_model_deployment_name>
|
||||
|
||||
vector_store:
|
||||
type: {defs.VECTOR_STORE_TYPE}
|
||||
db_uri: {defs.VECTOR_STORE_DB_URI}
|
||||
container_name: {defs.VECTOR_STORE_CONTAINER_NAME}
|
||||
overwrite: {defs.VECTOR_STORE_OVERWRITE}
|
||||
|
||||
embeddings:
|
||||
model_id: {defs.DEFAULT_EMBEDDING_MODEL_ID}
|
||||
|
||||
### Input settings ###
|
||||
|
||||
input:
|
||||
@ -51,14 +55,14 @@ input:
|
||||
file_type: {defs.INPUT_FILE_TYPE.value} # or csv
|
||||
base_dir: "{defs.INPUT_BASE_DIR}"
|
||||
file_encoding: {defs.INPUT_FILE_ENCODING}
|
||||
file_pattern: ".*\\\\.txt$"
|
||||
file_pattern: ".*\\\\.txt$$"
|
||||
|
||||
chunks:
|
||||
size: {defs.CHUNK_SIZE}
|
||||
overlap: {defs.CHUNK_OVERLAP}
|
||||
group_by_columns: [{",".join(defs.CHUNK_GROUP_BY_COLUMNS)}]
|
||||
|
||||
### Storage settings ###
|
||||
### Output settings ###
|
||||
## If blob storage is specified in the following four sections,
|
||||
## connection_string and container_name must be provided
|
||||
|
||||
@ -70,36 +74,38 @@ reporting:
|
||||
type: {defs.REPORTING_TYPE.value} # or console, blob
|
||||
base_dir: "{defs.REPORTING_BASE_DIR}"
|
||||
|
||||
storage:
|
||||
type: {defs.STORAGE_TYPE.value} # one of [blob, cosmosdb, file]
|
||||
base_dir: "{defs.STORAGE_BASE_DIR}"
|
||||
output:
|
||||
type: {defs.OUTPUT_TYPE.value} # one of [blob, cosmosdb, file]
|
||||
base_dir: "{defs.OUTPUT_BASE_DIR}"
|
||||
|
||||
## only turn this on if running `graphrag index` with custom settings
|
||||
## we normally use `graphrag update` with the defaults
|
||||
update_index_storage:
|
||||
# type: {defs.STORAGE_TYPE.value} # or blob
|
||||
# base_dir: "{defs.UPDATE_STORAGE_BASE_DIR}"
|
||||
update_index_output:
|
||||
# type: {defs.OUTPUT_TYPE.value} # or blob
|
||||
# base_dir: "{defs.UPDATE_OUTPUT_BASE_DIR}"
|
||||
|
||||
### Workflow settings ###
|
||||
|
||||
skip_workflows: []
|
||||
|
||||
entity_extraction:
|
||||
model_id: {defs.ENTITY_EXTRACTION_MODEL_ID}
|
||||
prompt: "prompts/entity_extraction.txt"
|
||||
entity_types: [{",".join(defs.ENTITY_EXTRACTION_ENTITY_TYPES)}]
|
||||
max_gleanings: {defs.ENTITY_EXTRACTION_MAX_GLEANINGS}
|
||||
|
||||
summarize_descriptions:
|
||||
model_id: {defs.SUMMARIZE_MODEL_ID}
|
||||
prompt: "prompts/summarize_descriptions.txt"
|
||||
max_length: {defs.SUMMARIZE_DESCRIPTIONS_MAX_LENGTH}
|
||||
|
||||
claim_extraction:
|
||||
enabled: false
|
||||
model_id: {defs.CLAIM_EXTRACTION_MODEL_ID}
|
||||
prompt: "prompts/claim_extraction.txt"
|
||||
description: "{defs.CLAIM_DESCRIPTION}"
|
||||
max_gleanings: {defs.CLAIM_MAX_GLEANINGS}
|
||||
|
||||
community_reports:
|
||||
model_id: {defs.COMMUNITY_REPORT_MODEL_ID}
|
||||
prompt: "prompts/community_report.txt"
|
||||
max_length: {defs.COMMUNITY_REPORT_MAX_LENGTH}
|
||||
max_input_length: {defs.COMMUNITY_REPORT_MAX_INPUT_LENGTH}
|
||||
|
||||
@ -3,23 +3,86 @@
|
||||
|
||||
"""Default method for loading config."""
|
||||
|
||||
import json
|
||||
import os
|
||||
from pathlib import Path
|
||||
from string import Template
|
||||
from typing import Any
|
||||
|
||||
import yaml
|
||||
from dotenv import load_dotenv
|
||||
|
||||
from graphrag.config.config_file_loader import (
|
||||
load_config_from_file,
|
||||
search_for_config_in_root_dir,
|
||||
)
|
||||
from graphrag.config.create_graphrag_config import create_graphrag_config
|
||||
from graphrag.config.models.graph_rag_config import GraphRagConfig
|
||||
|
||||
_default_config_files = ["settings.yaml", "settings.yml", "settings.json"]
|
||||
|
||||
def load_config(
|
||||
root_dir: Path,
|
||||
config_filepath: Path | None = None,
|
||||
) -> GraphRagConfig:
|
||||
"""Load configuration from a file or create a default configuration.
|
||||
|
||||
If a config file is not found the default configuration is created.
|
||||
def _search_for_config_in_root_dir(root: str | Path) -> Path | None:
|
||||
"""Resolve the config path from the given root directory.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
root : str | Path
|
||||
The path to the root directory containing the config file.
|
||||
Searches for a default config file (settings.{yaml,yml,json}).
|
||||
|
||||
Returns
|
||||
-------
|
||||
Path | None
|
||||
returns a Path if there is a config in the root directory
|
||||
Otherwise returns None.
|
||||
"""
|
||||
root = Path(root)
|
||||
|
||||
if not root.is_dir():
|
||||
msg = f"Invalid config path: {root} is not a directory"
|
||||
raise FileNotFoundError(msg)
|
||||
|
||||
for file in _default_config_files:
|
||||
if (root / file).is_file():
|
||||
return root / file
|
||||
|
||||
return None
|
||||
|
||||
|
||||
def _parse_env_variables(text: str) -> str:
|
||||
"""Parse environment variables in the configuration text.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
text : str
|
||||
The configuration text.
|
||||
|
||||
Returns
|
||||
-------
|
||||
str
|
||||
The configuration text with environment variables parsed.
|
||||
|
||||
Raises
|
||||
------
|
||||
KeyError
|
||||
If an environment variable is not found.
|
||||
"""
|
||||
return Template(text).substitute(os.environ)
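A small sketch of what this Template-based substitution does. Note that a literal dollar sign in a config file must be escaped as "$$", which is why the generated settings.yaml ends up with a file_pattern ending in "$$".

import os
from string import Template

os.environ["GRAPHRAG_API_KEY"] = "abc123"
text = 'api_key: ${GRAPHRAG_API_KEY}\nfile_pattern: ".*\\\\.txt$$"\n'
print(Template(text).substitute(os.environ))
# api_key: abc123
# file_pattern: ".*\\.txt$"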
|
||||
|
||||
|
||||
def _load_dotenv(config_path: Path | str) -> None:
|
||||
"""Load the .env file if it exists in the same directory as the config file.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
config_path : Path | str
|
||||
The path to the config file.
|
||||
"""
|
||||
config_path = Path(config_path)
|
||||
dotenv_path = config_path.parent / ".env"
|
||||
if dotenv_path.exists():
|
||||
load_dotenv(dotenv_path)
|
||||
|
||||
|
||||
def _get_config_path(root_dir: Path, config_filepath: Path | None) -> Path:
|
||||
"""Get the configuration file path.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
@ -27,23 +90,102 @@ def load_config(
|
||||
The root directory of the project. Will search for the config file in this directory.
|
||||
config_filepath : str | None
|
||||
The path to the config file.
|
||||
If None, searches for config file in root and
|
||||
if not found creates a default configuration.
|
||||
"""
|
||||
root = root_dir.resolve()
|
||||
If None, searches for config file in root.
|
||||
|
||||
# If user specified a config file path then it is required
|
||||
Returns
|
||||
-------
|
||||
Path
|
||||
The configuration file path.
|
||||
"""
|
||||
if config_filepath:
|
||||
config_path = config_filepath.resolve()
|
||||
if not config_path.exists():
|
||||
msg = f"Specified Config file not found: {config_path}"
|
||||
raise FileNotFoundError(msg)
|
||||
else:
|
||||
# resolve config filepath from the root directory if it exists
|
||||
config_path = search_for_config_in_root_dir(root)
|
||||
if config_path:
|
||||
config = load_config_from_file(config_path)
|
||||
else:
|
||||
config = create_graphrag_config(root_dir=str(root))
|
||||
config_path = _search_for_config_in_root_dir(root_dir)
|
||||
|
||||
return config
|
||||
if not config_path:
|
||||
msg = f"Config file not found in root directory: {root_dir}"
|
||||
raise FileNotFoundError(msg)
|
||||
|
||||
return config_path
|
||||
|
||||
|
||||
def _apply_overrides(data: dict[str, Any], overrides: dict[str, Any]) -> None:
|
||||
"""Apply the overrides to the raw configuration."""
|
||||
for key, value in overrides.items():
|
||||
keys = key.split(".")
|
||||
target = data
|
||||
current_path = keys[0]
|
||||
for k in keys[:-1]:
|
||||
current_path += f".{k}"
|
||||
target_obj = target.get(k, {})
|
||||
if not isinstance(target_obj, dict):
|
||||
msg = f"Cannot override non-dict value: data[{current_path}] is not a dict."
|
||||
raise TypeError(msg)
|
||||
target[k] = target_obj
|
||||
target = target[k]
|
||||
target[keys[-1]] = value
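As a quick illustration of the dotted-key expansion implemented above:

data = {"output": {"type": "file"}}
_apply_overrides(data, {"output.base_dir": "custom_output", "cache.type": "memory"})
# data == {"output": {"type": "file", "base_dir": "custom_output"},
#          "cache": {"type": "memory"}}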
|
||||
|
||||
|
||||
def _parse(file_extension: str, contents: str) -> dict[str, Any]:
|
||||
"""Parse configuration."""
|
||||
match file_extension:
|
||||
case ".yaml" | ".yml":
|
||||
return yaml.safe_load(contents)
|
||||
case ".json":
|
||||
return json.loads(contents)
|
||||
case _:
|
||||
msg = (
|
||||
f"Unable to parse config. Unsupported file extension: {file_extension}"
|
||||
)
|
||||
raise ValueError(msg)
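For instance:

_parse(".yaml", "output:\n  base_dir: out")         # -> {"output": {"base_dir": "out"}}
_parse(".json", '{"output": {"base_dir": "out"}}')  # -> {"output": {"base_dir": "out"}}
_parse(".toml", "...")                               # raises ValueError (unsupported extension)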
|
||||
|
||||
|
||||
def load_config(
|
||||
root_dir: Path,
|
||||
config_filepath: Path | None = None,
|
||||
cli_overrides: dict[str, Any] | None = None,
|
||||
) -> GraphRagConfig:
|
||||
"""Load configuration from a file.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
root_dir : str | Path
|
||||
The root directory of the project. Will search for the config file in this directory.
|
||||
config_filepath : str | None
|
||||
The path to the config file.
|
||||
If None, searches for config file in root.
|
||||
cli_overrides : dict[str, Any] | None
|
||||
A flat dictionary of cli overrides.
|
||||
Example: {'output.base_dir': 'override_value'}
|
||||
|
||||
Returns
|
||||
-------
|
||||
GraphRagConfig
|
||||
The loaded configuration.
|
||||
|
||||
Raises
|
||||
------
|
||||
FileNotFoundError
|
||||
If the config file is not found.
|
||||
ValueError
|
||||
If the config file extension is not supported.
|
||||
TypeError
|
||||
If applying cli overrides to the config fails.
|
||||
KeyError
|
||||
If config file references a non-existent environment variable.
|
||||
ValidationError
|
||||
If there are pydantic validation errors when instantiating the config.
|
||||
"""
|
||||
root = root_dir.resolve()
|
||||
config_path = _get_config_path(root, config_filepath)
|
||||
_load_dotenv(config_path)
|
||||
config_extension = config_path.suffix
|
||||
config_text = config_path.read_text(encoding="utf-8")
|
||||
config_text = _parse_env_variables(config_text)
|
||||
config_data = _parse(config_extension, config_text)
|
||||
if cli_overrides:
|
||||
_apply_overrides(config_data, cli_overrides)
|
||||
return create_graphrag_config(config_data, root_dir=str(root))
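A typical call, with the project root and override values below being purely illustrative:

from pathlib import Path

config = load_config(
    Path("./ragtest"),                                   # project root containing settings.yaml
    config_filepath=None,                                # or an explicit Path to a config file
    cli_overrides={"output.base_dir": "custom_output"},  # same dotted keys as _apply_overrides
)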
|
||||
|
||||
@ -22,15 +22,15 @@ class BasicSearchConfig(BaseModel):
|
||||
description="The conversation history maximum turns.",
|
||||
default=defs.BASIC_SEARCH_CONVERSATION_HISTORY_MAX_TURNS,
|
||||
)
|
||||
temperature: float | None = Field(
|
||||
temperature: float = Field(
|
||||
description="The temperature to use for token generation.",
|
||||
default=defs.BASIC_SEARCH_LLM_TEMPERATURE,
|
||||
)
|
||||
top_p: float | None = Field(
|
||||
top_p: float = Field(
|
||||
description="The top-p value to use for token generation.",
|
||||
default=defs.BASIC_SEARCH_LLM_TOP_P,
|
||||
)
|
||||
n: int | None = Field(
|
||||
n: int = Field(
|
||||
description="The number of completions to generate.",
|
||||
default=defs.BASIC_SEARCH_LLM_N,
|
||||
)
|
||||
|
||||
@ -3,22 +3,10 @@
|
||||
|
||||
"""Parameterization settings for the default configuration."""
|
||||
|
||||
from enum import Enum
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
import graphrag.config.defaults as defs
|
||||
|
||||
|
||||
class ChunkStrategyType(str, Enum):
|
||||
"""ChunkStrategy class definition."""
|
||||
|
||||
tokens = "tokens"
|
||||
sentence = "sentence"
|
||||
|
||||
def __repr__(self):
|
||||
"""Get a string representation."""
|
||||
return f'"{self.value}"'
|
||||
from graphrag.config.enums import ChunkStrategyType
|
||||
|
||||
|
||||
class ChunkingConfig(BaseModel):
|
||||
@ -33,7 +21,7 @@ class ChunkingConfig(BaseModel):
|
||||
default=defs.CHUNK_GROUP_BY_COLUMNS,
|
||||
)
|
||||
strategy: ChunkStrategyType = Field(
|
||||
description="The chunking strategy to use.", default=ChunkStrategyType.tokens
|
||||
description="The chunking strategy to use.", default=defs.CHUNK_STRATEGY
|
||||
)
|
||||
encoding_model: str = Field(
|
||||
description="The encoding model to use.", default=defs.ENCODING_MODEL
|
||||
|
||||
@ -5,17 +5,18 @@
|
||||
|
||||
from pathlib import Path
|
||||
|
||||
from pydantic import Field
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
import graphrag.config.defaults as defs
|
||||
from graphrag.config.models.llm_config import LLMConfig
|
||||
from graphrag.config.models.language_model_config import LanguageModelConfig
|
||||
|
||||
|
||||
class ClaimExtractionConfig(LLMConfig):
|
||||
class ClaimExtractionConfig(BaseModel):
|
||||
"""Configuration section for claim extraction."""
|
||||
|
||||
enabled: bool = Field(
|
||||
description="Whether claim extraction is enabled.",
|
||||
default=defs.CLAIM_EXTRACTION_ENABLED,
|
||||
)
|
||||
prompt: str | None = Field(
|
||||
description="The claim extraction prompt to use.", default=None
|
||||
@ -34,18 +35,25 @@ class ClaimExtractionConfig(LLMConfig):
|
||||
encoding_model: str | None = Field(
|
||||
default=None, description="The encoding model to use."
|
||||
)
|
||||
model_id: str = Field(
|
||||
description="The model ID to use for claim extraction.",
|
||||
default=defs.CLAIM_EXTRACTION_MODEL_ID,
|
||||
)
|
||||
|
||||
def resolved_strategy(self, root_dir: str, encoding_model: str | None) -> dict:
|
||||
def resolved_strategy(
|
||||
self, root_dir: str, model_config: LanguageModelConfig
|
||||
) -> dict:
|
||||
"""Get the resolved claim extraction strategy."""
|
||||
return self.strategy or {
|
||||
"llm": self.llm.model_dump(),
|
||||
**self.parallelization.model_dump(),
|
||||
"extraction_prompt": (Path(root_dir) / self.prompt)
|
||||
.read_bytes()
|
||||
.decode(encoding="utf-8")
|
||||
"llm": model_config.model_dump(),
|
||||
"stagger": model_config.parallelization_stagger,
|
||||
"num_threads": model_config.parallelization_num_threads,
|
||||
"extraction_prompt": (Path(root_dir) / self.prompt).read_text(
|
||||
encoding="utf-8"
|
||||
)
|
||||
if self.prompt
|
||||
else None,
|
||||
"claim_description": self.description,
|
||||
"max_gleanings": self.max_gleanings,
|
||||
"encoding_name": encoding_model or self.encoding_model,
|
||||
"encoding_name": model_config.encoding_model,
|
||||
}
|
||||
|
||||
@ -18,7 +18,7 @@ class ClusterGraphConfig(BaseModel):
|
||||
description="Whether to use the largest connected component.",
|
||||
default=defs.USE_LCC,
|
||||
)
|
||||
seed: int | None = Field(
|
||||
seed: int = Field(
|
||||
description="The seed to use for the clustering.",
|
||||
default=defs.CLUSTER_GRAPH_SEED,
|
||||
)
|
||||
|
||||
@ -5,13 +5,13 @@
|
||||
|
||||
from pathlib import Path
|
||||
|
||||
from pydantic import Field
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
import graphrag.config.defaults as defs
|
||||
from graphrag.config.models.llm_config import LLMConfig
|
||||
from graphrag.config.models.language_model_config import LanguageModelConfig
|
||||
|
||||
|
||||
class CommunityReportsConfig(LLMConfig):
|
||||
class CommunityReportsConfig(BaseModel):
|
||||
"""Configuration section for community reports."""
|
||||
|
||||
prompt: str | None = Field(
|
||||
@ -28,8 +28,14 @@ class CommunityReportsConfig(LLMConfig):
|
||||
strategy: dict | None = Field(
|
||||
description="The override strategy to use.", default=None
|
||||
)
|
||||
model_id: str = Field(
|
||||
description="The model ID to use for community reports.",
|
||||
default=defs.COMMUNITY_REPORT_MODEL_ID,
|
||||
)
|
||||
|
||||
def resolved_strategy(self, root_dir) -> dict:
|
||||
def resolved_strategy(
|
||||
self, root_dir: str, model_config: LanguageModelConfig
|
||||
) -> dict:
|
||||
"""Get the resolved community report extraction strategy."""
|
||||
from graphrag.index.operations.summarize_communities import (
|
||||
CreateCommunityReportsStrategyType,
|
||||
@ -37,11 +43,12 @@ class CommunityReportsConfig(LLMConfig):
|
||||
|
||||
return self.strategy or {
|
||||
"type": CreateCommunityReportsStrategyType.graph_intelligence,
|
||||
"llm": self.llm.model_dump(),
|
||||
**self.parallelization.model_dump(),
|
||||
"extraction_prompt": (Path(root_dir) / self.prompt)
|
||||
.read_bytes()
|
||||
.decode(encoding="utf-8")
|
||||
"llm": model_config.model_dump(),
|
||||
"stagger": model_config.parallelization_stagger,
|
||||
"num_threads": model_config.parallelization_num_threads,
|
||||
"extraction_prompt": (Path(root_dir) / self.prompt).read_text(
|
||||
encoding="utf-8"
|
||||
)
|
||||
if self.prompt
|
||||
else None,
|
||||
"max_report_length": self.max_length,
|
||||
|
||||
@ -5,13 +5,13 @@
|
||||
|
||||
from pathlib import Path
|
||||
|
||||
from pydantic import Field
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
import graphrag.config.defaults as defs
|
||||
from graphrag.config.models.llm_config import LLMConfig
|
||||
from graphrag.config.models.language_model_config import LanguageModelConfig
|
||||
|
||||
|
||||
class EntityExtractionConfig(LLMConfig):
|
||||
class EntityExtractionConfig(BaseModel):
|
||||
"""Configuration section for entity extraction."""
|
||||
|
||||
prompt: str | None = Field(
|
||||
@ -31,8 +31,14 @@ class EntityExtractionConfig(LLMConfig):
|
||||
encoding_model: str | None = Field(
|
||||
default=None, description="The encoding model to use."
|
||||
)
|
||||
model_id: str = Field(
|
||||
description="The model ID to use for text embeddings.",
|
||||
default=defs.ENTITY_EXTRACTION_MODEL_ID,
|
||||
)
|
||||
|
||||
def resolved_strategy(self, root_dir: str, encoding_model: str | None) -> dict:
|
||||
def resolved_strategy(
|
||||
self, root_dir: str, model_config: LanguageModelConfig
|
||||
) -> dict:
|
||||
"""Get the resolved entity extraction strategy."""
|
||||
from graphrag.index.operations.extract_entities import (
|
||||
ExtractEntityStrategyType,
|
||||
@ -40,13 +46,14 @@ class EntityExtractionConfig(LLMConfig):
|
||||
|
||||
return self.strategy or {
|
||||
"type": ExtractEntityStrategyType.graph_intelligence,
|
||||
"llm": self.llm.model_dump(),
|
||||
**self.parallelization.model_dump(),
|
||||
"extraction_prompt": (Path(root_dir) / self.prompt)
|
||||
.read_bytes()
|
||||
.decode(encoding="utf-8")
|
||||
"llm": model_config.model_dump(),
|
||||
"stagger": model_config.parallelization_stagger,
|
||||
"num_threads": model_config.parallelization_num_threads,
|
||||
"extraction_prompt": (Path(root_dir) / self.prompt).read_text(
|
||||
encoding="utf-8"
|
||||
)
|
||||
if self.prompt
|
||||
else None,
|
||||
"max_gleanings": self.max_gleanings,
|
||||
"encoding_name": encoding_model or self.encoding_model,
|
||||
"encoding_name": model_config.encoding_model,
|
||||
}
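A sketch of the intended wiring (assumed from the new signature): callers resolve the referenced language model first and hand it to resolved_strategy, instead of relying on an inherited llm block.

model_config = config.get_language_model_config(config.entity_extraction.model_id)
strategy = config.entity_extraction.resolved_strategy(config.root_dir, model_config)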
|
||||
|
||||
@ -20,15 +20,15 @@ class GlobalSearchConfig(BaseModel):
|
||||
knowledge_prompt: str | None = Field(
|
||||
description="The global search general prompt to use.", default=None
|
||||
)
|
||||
temperature: float | None = Field(
|
||||
temperature: float = Field(
|
||||
description="The temperature to use for token generation.",
|
||||
default=defs.GLOBAL_SEARCH_LLM_TEMPERATURE,
|
||||
)
|
||||
top_p: float | None = Field(
|
||||
top_p: float = Field(
|
||||
description="The top-p value to use for token generation.",
|
||||
default=defs.GLOBAL_SEARCH_LLM_TOP_P,
|
||||
)
|
||||
n: int | None = Field(
|
||||
n: int = Field(
|
||||
description="The number of completions to generate.",
|
||||
default=defs.GLOBAL_SEARCH_LLM_N,
|
||||
)
|
||||
|
||||
@ -3,10 +3,13 @@
|
||||
|
||||
"""Parameterization settings for the default configuration."""
|
||||
|
||||
from pathlib import Path
|
||||
|
||||
from devtools import pformat
|
||||
from pydantic import Field
|
||||
from pydantic import BaseModel, Field, model_validator
|
||||
|
||||
import graphrag.config.defaults as defs
|
||||
from graphrag.config.errors import LanguageModelConfigMissingError
|
||||
from graphrag.config.models.basic_search_config import BasicSearchConfig
|
||||
from graphrag.config.models.cache_config import CacheConfig
|
||||
from graphrag.config.models.chunking_config import ChunkingConfig
|
||||
@ -18,19 +21,21 @@ from graphrag.config.models.embed_graph_config import EmbedGraphConfig
|
||||
from graphrag.config.models.entity_extraction_config import EntityExtractionConfig
|
||||
from graphrag.config.models.global_search_config import GlobalSearchConfig
|
||||
from graphrag.config.models.input_config import InputConfig
|
||||
from graphrag.config.models.llm_config import LLMConfig
|
||||
from graphrag.config.models.language_model_config import LanguageModelConfig
|
||||
from graphrag.config.models.local_search_config import LocalSearchConfig
|
||||
from graphrag.config.models.output_config import OutputConfig
|
||||
from graphrag.config.models.reporting_config import ReportingConfig
|
||||
from graphrag.config.models.snapshots_config import SnapshotsConfig
|
||||
from graphrag.config.models.storage_config import StorageConfig
|
||||
from graphrag.config.models.summarize_descriptions_config import (
|
||||
SummarizeDescriptionsConfig,
|
||||
)
|
||||
from graphrag.config.models.text_embedding_config import TextEmbeddingConfig
|
||||
from graphrag.config.models.umap_config import UmapConfig
|
||||
from graphrag.config.models.vector_store_config import VectorStoreConfig
|
||||
from graphrag.vector_stores.factory import VectorStoreType
|
||||
|
||||
|
||||
class GraphRagConfig(LLMConfig):
|
||||
class GraphRagConfig(BaseModel):
|
||||
"""Base class for the Default-Configuration parameterization settings."""
|
||||
|
||||
def __repr__(self) -> str:
|
||||
@ -42,24 +47,91 @@ class GraphRagConfig(LLMConfig):
|
||||
return self.model_dump_json(indent=4)
|
||||
|
||||
root_dir: str = Field(
|
||||
description="The root directory for the configuration.", default="."
|
||||
description="The root directory for the configuration.", default=""
|
||||
)
|
||||
|
||||
def _validate_root_dir(self) -> None:
|
||||
"""Validate the root directory."""
|
||||
if self.root_dir.strip() == "":
|
||||
self.root_dir = str(Path.cwd())
|
||||
|
||||
root_dir = Path(self.root_dir).resolve()
|
||||
if not root_dir.is_dir():
|
||||
msg = f"Invalid root directory: {self.root_dir} is not a directory."
|
||||
raise FileNotFoundError(msg)
|
||||
self.root_dir = str(root_dir)
|
||||
|
||||
models: dict[str, LanguageModelConfig] = Field(
|
||||
description="Available language model configurations.",
|
||||
default={},
|
||||
)
|
||||
|
||||
def _validate_models(self) -> None:
|
||||
"""Validate the models configuration.
|
||||
|
||||
Ensure that both a default chat model and a default embedding model
have been defined. Other models may also be defined, but the defaults
are required for now because parts of the code still fall back to the
default model configs instead of naming a specific model.

TODO: Don't fall back to default models elsewhere in the code.
Force callers to specify a model to use and allow arbitrary names
for model configurations.
|
||||
"""
|
||||
if defs.DEFAULT_CHAT_MODEL_ID not in self.models:
|
||||
raise LanguageModelConfigMissingError(defs.DEFAULT_CHAT_MODEL_ID)
|
||||
if defs.DEFAULT_EMBEDDING_MODEL_ID not in self.models:
|
||||
raise LanguageModelConfigMissingError(defs.DEFAULT_EMBEDDING_MODEL_ID)
|
||||
|
||||
reporting: ReportingConfig = Field(
|
||||
description="The reporting configuration.", default=ReportingConfig()
|
||||
)
|
||||
"""The reporting configuration."""
|
||||
|
||||
storage: StorageConfig = Field(
|
||||
description="The storage configuration.", default=StorageConfig()
|
||||
)
|
||||
"""The storage configuration."""
|
||||
def _validate_reporting_base_dir(self) -> None:
|
||||
"""Validate the reporting base directory."""
|
||||
if self.reporting.type == defs.ReportingType.file:
|
||||
if self.reporting.base_dir.strip() == "":
|
||||
msg = "Reporting base directory is required for file reporting. Please rerun `graphrag init` and set the reporting configuration."
|
||||
raise ValueError(msg)
|
||||
self.reporting.base_dir = str(
|
||||
(Path(self.root_dir) / self.reporting.base_dir).resolve()
|
||||
)
|
||||
|
||||
update_index_storage: StorageConfig | None = Field(
|
||||
description="The storage configuration for the updated index.",
|
||||
output: OutputConfig = Field(
|
||||
description="The output configuration.", default=OutputConfig()
|
||||
)
|
||||
"""The output configuration."""
|
||||
|
||||
def _validate_output_base_dir(self) -> None:
|
||||
"""Validate the output base directory."""
|
||||
if self.output.type == defs.OutputType.file:
|
||||
if self.output.base_dir.strip() == "":
|
||||
msg = "output base directory is required for file output. Please rerun `graphrag init` and set the output configuration."
|
||||
raise ValueError(msg)
|
||||
self.output.base_dir = str(
|
||||
(Path(self.root_dir) / self.output.base_dir).resolve()
|
||||
)
|
||||
|
||||
update_index_output: OutputConfig | None = Field(
|
||||
description="The output configuration for the updated index.",
|
||||
default=None,
|
||||
)
|
||||
"""The storage configuration for the updated index."""
|
||||
"""The output configuration for the updated index."""
|
||||
|
||||
def _validate_update_index_output_base_dir(self) -> None:
|
||||
"""Validate the update index output base directory."""
|
||||
if (
|
||||
self.update_index_output
|
||||
and self.update_index_output.type == defs.OutputType.file
|
||||
):
|
||||
if self.update_index_output.base_dir.strip() == "":
|
||||
msg = "Update index output base directory is required for file output. Please rerun `graphrag init` and set the update index output configuration."
|
||||
raise ValueError(msg)
|
||||
self.update_index_output.base_dir = str(
|
||||
(Path(self.root_dir) / self.update_index_output.base_dir).resolve()
|
||||
)
|
||||
|
||||
cache: CacheConfig = Field(
|
||||
description="The cache configuration.", default=CacheConfig()
|
||||
@ -152,12 +224,52 @@ class GraphRagConfig(LLMConfig):
|
||||
)
|
||||
"""The basic search configuration."""
|
||||
|
||||
encoding_model: str = Field(
|
||||
description="The encoding model to use.", default=defs.ENCODING_MODEL
|
||||
vector_store: VectorStoreConfig = Field(
|
||||
description="The vector store configuration.", default=VectorStoreConfig()
|
||||
)
|
||||
"""The encoding model to use."""
|
||||
"""The vector store configuration."""
|
||||
|
||||
skip_workflows: list[str] = Field(
|
||||
description="The workflows to skip, usually for testing reasons.", default=[]
|
||||
)
|
||||
"""The workflows to skip, usually for testing reasons."""
|
||||
def _validate_vector_store_db_uri(self) -> None:
|
||||
"""Validate the vector store configuration."""
|
||||
if self.vector_store.type == VectorStoreType.LanceDB:
|
||||
if not self.vector_store.db_uri or self.vector_store.db_uri.strip == "":
|
||||
msg = "Vector store URI is required for LanceDB. Please rerun `graphrag init` and set the vector store configuration."
|
||||
raise ValueError(msg)
|
||||
self.vector_store.db_uri = str(
|
||||
(Path(self.root_dir) / self.vector_store.db_uri).resolve()
|
||||
)
|
||||
|
||||
def get_language_model_config(self, model_id: str) -> LanguageModelConfig:
|
||||
"""Get a model configuration by ID.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
model_id : str
|
||||
The ID of the model to get. Should match an ID in the models list.
|
||||
|
||||
Returns
|
||||
-------
|
||||
LanguageModelConfig
|
||||
The model configuration if found.
|
||||
|
||||
Raises
|
||||
------
|
||||
ValueError
|
||||
If the model ID is not found in the configuration.
|
||||
"""
|
||||
if model_id not in self.models:
|
||||
err_msg = f"Model ID {model_id} not found in configuration. Please rerun `graphrag init` and set the model configuration."
|
||||
raise ValueError(err_msg)
|
||||
|
||||
return self.models[model_id]
|
||||
|
||||
@model_validator(mode="after")
|
||||
def _validate_model(self):
|
||||
"""Validate the model configuration."""
|
||||
self._validate_root_dir()
|
||||
self._validate_models()
|
||||
self._validate_reporting_base_dir()
|
||||
self._validate_output_base_dir()
|
||||
self._validate_update_index_output_base_dir()
|
||||
self._validate_vector_store_db_uri()
|
||||
return self
|
||||
|
||||
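The `_validate_models` and `get_language_model_config` pieces above are easiest to see end to end with a small sketch. The following is illustrative only and not part of this commit: model names and the API key are placeholders, and the two required IDs come from graphrag.config.defaults.

# Illustrative only: a minimal GraphRagConfig that satisfies _validate_models
# by providing both required default model entries. Model names and the API
# key are placeholders, not values taken from this commit.
import graphrag.config.defaults as defs
from graphrag.config.enums import LLMType
from graphrag.config.models.graph_rag_config import GraphRagConfig
from graphrag.config.models.language_model_config import LanguageModelConfig

config = GraphRagConfig(
    root_dir=".",
    models={
        defs.DEFAULT_CHAT_MODEL_ID: LanguageModelConfig(
            type=LLMType.OpenAIChat, model="gpt-4o", api_key="<openai-key>"
        ),
        defs.DEFAULT_EMBEDDING_MODEL_ID: LanguageModelConfig(
            type=LLMType.OpenAIEmbedding,
            model="text-embedding-3-small",
            api_key="<openai-key>",
        ),
    },
)
chat_model = config.get_language_model_config(defs.DEFAULT_CHAT_MODEL_ID)

Looking up any other ID that is not present in `models` raises the ValueError shown in the method above.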
@@ -30,7 +30,7 @@ class InputConfig(BaseModel):
container_name: str | None = Field(
description="The azure blob storage container name to use.", default=None
)
encoding: str | None = Field(
encoding: str = Field(
description="The input file encoding to use.",
default=defs.INPUT_FILE_ENCODING,
)

239 graphrag/config/models/language_model_config.py Normal file
@@ -0,0 +1,239 @@
# Copyright (c) 2024 Microsoft Corporation.
# Licensed under the MIT License

"""Language model configuration."""

import tiktoken
from pydantic import BaseModel, Field, model_validator

import graphrag.config.defaults as defs
from graphrag.config.enums import AsyncType, AzureAuthType, LLMType
from graphrag.config.errors import (
ApiKeyMissingError,
AzureApiBaseMissingError,
AzureApiVersionMissingError,
AzureDeploymentNameMissingError,
ConflictingSettingsError,
)


class LanguageModelConfig(BaseModel):
"""Language model configuration."""

api_key: str | None = Field(
description="The API key to use for the LLM service.",
default=None,
)

def _validate_api_key(self) -> None:
"""Validate the API key.

An API key is required when using the OpenAI API
or when using the Azure API with API key authentication.
For the time being, this check is extra verbose for clarity.
It will also throw an exception if an API key is provided
when one is not expected, such as when using Azure
Managed Identity.

Raises
------
ApiKeyMissingError
If the API key is missing and is required.
"""
if (
self.type == LLMType.OpenAIEmbedding
or self.type == LLMType.OpenAIChat
or self.azure_auth_type == AzureAuthType.APIKey
) and (self.api_key is None or self.api_key.strip() == ""):
raise ApiKeyMissingError(
self.type.value,
self.azure_auth_type.value if self.azure_auth_type else None,
)

if (self.azure_auth_type == AzureAuthType.ManagedIdentity) and (
self.api_key is not None and self.api_key.strip() != ""
):
msg = "API Key should not be provided when using Azure Managed Identity. Please rerun `graphrag init` and remove the api_key when using Azure Managed Identity."
raise ConflictingSettingsError(msg)

azure_auth_type: AzureAuthType | None = Field(
description="The Azure authentication type to use when using Azure OpenAI.",
default=None,
)

type: LLMType = Field(description="The type of LLM model to use.")
model: str = Field(description="The LLM model to use.")
encoding_model: str = Field(description="The encoding model to use.", default="")

def _validate_encoding_model(self) -> None:
"""Validate the encoding model.

Raises
------
KeyError
If the model name is not recognized.
"""
if self.encoding_model.strip() == "":
self.encoding_model = tiktoken.encoding_name_for_model(self.model)

max_tokens: int = Field(
description="The maximum number of tokens to generate.",
default=defs.LLM_MAX_TOKENS,
)
temperature: float = Field(
description="The temperature to use for token generation.",
default=defs.LLM_TEMPERATURE,
)
top_p: float = Field(
description="The top-p value to use for token generation.",
default=defs.LLM_TOP_P,
)
n: int = Field(
description="The number of completions to generate.",
default=defs.LLM_N,
)
frequency_penalty: float = Field(
description="The frequency penalty to use for token generation.",
default=defs.LLM_FREQUENCY_PENALTY,
)
presence_penalty: float = Field(
description="The presence penalty to use for token generation.",
default=defs.LLM_PRESENCE_PENALTY,
)
request_timeout: float = Field(
description="The request timeout to use.", default=defs.LLM_REQUEST_TIMEOUT
)
api_base: str | None = Field(
description="The base URL for the LLM API.", default=None
)

def _validate_api_base(self) -> None:
"""Validate the API base.

Required when using Azure OpenAI.

Raises
------
AzureApiBaseMissingError
If the API base is missing and is required.
"""
if (
self.type == LLMType.AzureOpenAIChat
or self.type == LLMType.AzureOpenAIEmbedding
) and (self.api_base is None or self.api_base.strip() == ""):
raise AzureApiBaseMissingError(self.type.value)

api_version: str | None = Field(
description="The version of the LLM API to use.", default=None
)

def _validate_api_version(self) -> None:
"""Validate the API version.

Required when using Azure OpenAI.

Raises
------
AzureApiVersionMissingError
If the API version is missing and is required.
"""
if (
self.type == LLMType.AzureOpenAIChat
or self.type == LLMType.AzureOpenAIEmbedding
) and (self.api_version is None or self.api_version.strip() == ""):
raise AzureApiVersionMissingError(self.type.value)

deployment_name: str | None = Field(
description="The deployment name to use for the LLM service.", default=None
)

def _validate_deployment_name(self) -> None:
"""Validate the deployment name.

Required when using Azure OpenAI.

Raises
------
AzureDeploymentNameMissingError
If the deployment name is missing and is required.
"""
if (
self.type == LLMType.AzureOpenAIChat
or self.type == LLMType.AzureOpenAIEmbedding
) and (self.deployment_name is None or self.deployment_name.strip() == ""):
raise AzureDeploymentNameMissingError(self.type.value)

organization: str | None = Field(
description="The organization to use for the LLM service.", default=None
)
proxy: str | None = Field(
description="The proxy to use for the LLM service.", default=None
)
audience: str | None = Field(
description="Azure resource URI to use with managed identity for the llm connection.",
default=None,
)
model_supports_json: bool | None = Field(
description="Whether the model supports JSON output mode.", default=None
)
tokens_per_minute: int = Field(
description="The number of tokens per minute to use for the LLM service.",
default=defs.LLM_TOKENS_PER_MINUTE,
)
requests_per_minute: int = Field(
description="The number of requests per minute to use for the LLM service.",
default=defs.LLM_REQUESTS_PER_MINUTE,
)
max_retries: int = Field(
description="The maximum number of retries to use for the LLM service.",
default=defs.LLM_MAX_RETRIES,
)
max_retry_wait: float = Field(
description="The maximum retry wait to use for the LLM service.",
default=defs.LLM_MAX_RETRY_WAIT,
)
sleep_on_rate_limit_recommendation: bool = Field(
description="Whether to sleep on rate limit recommendations.",
default=defs.LLM_SLEEP_ON_RATE_LIMIT_RECOMMENDATION,
)
concurrent_requests: int = Field(
description="The number of concurrent requests to allow for the LLM service.",
default=defs.LLM_CONCURRENT_REQUESTS,
)
responses: list[str | BaseModel] | None = Field(
default=None, description="Static responses to use in mock mode."
)
parallelization_stagger: float = Field(
description="The stagger to use for the LLM service.",
default=defs.PARALLELIZATION_STAGGER,
)
parallelization_num_threads: int = Field(
description="The number of threads to use for the LLM service.",
default=defs.PARALLELIZATION_NUM_THREADS,
)
async_mode: AsyncType = Field(
description="The async mode to use.", default=defs.ASYNC_MODE
)

def _validate_azure_settings(self) -> None:
"""Validate the Azure settings.

Raises
------
AzureApiBaseMissingError
If the API base is missing and is required.
AzureApiVersionMissingError
If the API version is missing and is required.
AzureDeploymentNameMissingError
If the deployment name is missing and is required.
"""
self._validate_api_base()
self._validate_api_version()
self._validate_deployment_name()

@model_validator(mode="after")
def _validate_model(self):
self._validate_api_key()
self._validate_azure_settings()
self._validate_encoding_model()
return self
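The Azure-specific validators above are easiest to read with a concrete entry. This sketch is illustrative only and not taken from this commit: the endpoint, API version, and deployment name are placeholders, and the managed-identity path deliberately omits api_key, which the key validator would otherwise reject.

# Illustrative only: an Azure OpenAI chat entry that passes the validators
# above (api_base, api_version, and deployment_name are required for the
# AzureOpenAI* types; no api_key is allowed with managed identity).
from graphrag.config.enums import AzureAuthType, LLMType
from graphrag.config.models.language_model_config import LanguageModelConfig

azure_chat = LanguageModelConfig(
    type=LLMType.AzureOpenAIChat,
    azure_auth_type=AzureAuthType.ManagedIdentity,
    model="gpt-4o",
    api_base="https://<resource>.openai.azure.com",
    api_version="2024-02-15-preview",
    deployment_name="<deployment>",
)

Because encoding_model is left blank here, _validate_encoding_model fills it in via tiktoken.encoding_name_for_model, assuming the installed tiktoken version recognizes the model name.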
@ -1,26 +0,0 @@
|
||||
# Copyright (c) 2024 Microsoft Corporation.
|
||||
# Licensed under the MIT License
|
||||
|
||||
"""Parameterization settings for the default configuration."""
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
import graphrag.config.defaults as defs
|
||||
from graphrag.config.enums import AsyncType
|
||||
from graphrag.config.models.llm_parameters import LLMParameters
|
||||
from graphrag.config.models.parallelization_parameters import ParallelizationParameters
|
||||
|
||||
|
||||
class LLMConfig(BaseModel):
|
||||
"""Base class for LLM-configured steps."""
|
||||
|
||||
llm: LLMParameters = Field(
|
||||
description="The LLM configuration to use.", default=LLMParameters()
|
||||
)
|
||||
parallelization: ParallelizationParameters = Field(
|
||||
description="The parallelization configuration to use.",
|
||||
default=ParallelizationParameters(),
|
||||
)
|
||||
async_mode: AsyncType = Field(
|
||||
description="The async mode to use.", default=defs.ASYNC_MODE
|
||||
)
|
||||
@ -1,102 +0,0 @@
|
||||
# Copyright (c) 2024 Microsoft Corporation.
|
||||
# Licensed under the MIT License
|
||||
|
||||
"""LLM Parameters model."""
|
||||
|
||||
from pydantic import BaseModel, ConfigDict, Field
|
||||
|
||||
import graphrag.config.defaults as defs
|
||||
from graphrag.config.enums import LLMType
|
||||
|
||||
|
||||
class LLMParameters(BaseModel):
|
||||
"""LLM Parameters model."""
|
||||
|
||||
model_config = ConfigDict(protected_namespaces=(), extra="allow")
|
||||
api_key: str | None = Field(
|
||||
description="The API key to use for the LLM service.",
|
||||
default=None,
|
||||
)
|
||||
type: LLMType = Field(
|
||||
description="The type of LLM model to use.", default=defs.LLM_TYPE
|
||||
)
|
||||
encoding_model: str | None = Field(
|
||||
description="The encoding model to use", default=defs.ENCODING_MODEL
|
||||
)
|
||||
model: str = Field(description="The LLM model to use.", default=defs.LLM_MODEL)
|
||||
max_tokens: int | None = Field(
|
||||
description="The maximum number of tokens to generate.",
|
||||
default=defs.LLM_MAX_TOKENS,
|
||||
)
|
||||
temperature: float | None = Field(
|
||||
description="The temperature to use for token generation.",
|
||||
default=defs.LLM_TEMPERATURE,
|
||||
)
|
||||
top_p: float | None = Field(
|
||||
description="The top-p value to use for token generation.",
|
||||
default=defs.LLM_TOP_P,
|
||||
)
|
||||
n: int | None = Field(
|
||||
description="The number of completions to generate.",
|
||||
default=defs.LLM_N,
|
||||
)
|
||||
frequency_penalty: float | None = Field(
|
||||
description="The frequency penalty to use for token generation.",
|
||||
default=defs.LLM_FREQUENCY_PENALTY,
|
||||
)
|
||||
presence_penalty: float | None = Field(
|
||||
description="The presence penalty to use for token generation.",
|
||||
default=defs.LLM_PRESENCE_PENALTY,
|
||||
)
|
||||
request_timeout: float = Field(
|
||||
description="The request timeout to use.", default=defs.LLM_REQUEST_TIMEOUT
|
||||
)
|
||||
api_base: str | None = Field(
|
||||
description="The base URL for the LLM API.", default=None
|
||||
)
|
||||
api_version: str | None = Field(
|
||||
description="The version of the LLM API to use.", default=None
|
||||
)
|
||||
organization: str | None = Field(
|
||||
description="The organization to use for the LLM service.", default=None
|
||||
)
|
||||
proxy: str | None = Field(
|
||||
description="The proxy to use for the LLM service.", default=None
|
||||
)
|
||||
audience: str | None = Field(
|
||||
description="Azure resource URI to use with managed identity for the llm connection.",
|
||||
default=None,
|
||||
)
|
||||
deployment_name: str | None = Field(
|
||||
description="The deployment name to use for the LLM service.", default=None
|
||||
)
|
||||
model_supports_json: bool | None = Field(
|
||||
description="Whether the model supports JSON output mode.", default=None
|
||||
)
|
||||
tokens_per_minute: int = Field(
|
||||
description="The number of tokens per minute to use for the LLM service.",
|
||||
default=defs.LLM_TOKENS_PER_MINUTE,
|
||||
)
|
||||
requests_per_minute: int = Field(
|
||||
description="The number of requests per minute to use for the LLM service.",
|
||||
default=defs.LLM_REQUESTS_PER_MINUTE,
|
||||
)
|
||||
max_retries: int = Field(
|
||||
description="The maximum number of retries to use for the LLM service.",
|
||||
default=defs.LLM_MAX_RETRIES,
|
||||
)
|
||||
max_retry_wait: float = Field(
|
||||
description="The maximum retry wait to use for the LLM service.",
|
||||
default=defs.LLM_MAX_RETRY_WAIT,
|
||||
)
|
||||
sleep_on_rate_limit_recommendation: bool = Field(
|
||||
description="Whether to sleep on rate limit recommendations.",
|
||||
default=defs.LLM_SLEEP_ON_RATE_LIMIT_RECOMMENDATION,
|
||||
)
|
||||
concurrent_requests: int = Field(
|
||||
description="Whether to use concurrent requests for the LLM service.",
|
||||
default=defs.LLM_CONCURRENT_REQUESTS,
|
||||
)
|
||||
responses: list[str | BaseModel] | None = Field(
|
||||
default=None, description="Static responses to use in mock mode."
|
||||
)
|
||||
@@ -34,15 +34,15 @@ class LocalSearchConfig(BaseModel):
description="The top k mapped relations.",
default=defs.LOCAL_SEARCH_TOP_K_RELATIONSHIPS,
)
temperature: float | None = Field(
temperature: float = Field(
description="The temperature to use for token generation.",
default=defs.LOCAL_SEARCH_LLM_TEMPERATURE,
)
top_p: float | None = Field(
top_p: float = Field(
description="The top-p value to use for token generation.",
default=defs.LOCAL_SEARCH_LLM_TOP_P,
)
n: int | None = Field(
n: int = Field(
description="The number of completions to generate.",
default=defs.LOCAL_SEARCH_LLM_N,
)

@@ -6,18 +6,18 @@
from pydantic import BaseModel, Field

import graphrag.config.defaults as defs
from graphrag.config.enums import StorageType
from graphrag.config.enums import OutputType


class StorageConfig(BaseModel):
"""The default configuration section for Storage."""
class OutputConfig(BaseModel):
"""The default configuration section for Output."""

type: StorageType = Field(
description="The storage type to use.", default=defs.STORAGE_TYPE
type: OutputType = Field(
description="The output type to use.", default=defs.OUTPUT_TYPE
)
base_dir: str = Field(
description="The base directory for the storage.",
default=defs.STORAGE_BASE_DIR,
description="The base directory for the output.",
default=defs.OUTPUT_BASE_DIR,
)
connection_string: str | None = Field(
description="The storage connection string to use.", default=None
@ -1,21 +0,0 @@
|
||||
# Copyright (c) 2024 Microsoft Corporation.
|
||||
# Licensed under the MIT License
|
||||
|
||||
"""LLM Parameters model."""
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
import graphrag.config.defaults as defs
|
||||
|
||||
|
||||
class ParallelizationParameters(BaseModel):
|
||||
"""LLM Parameters model."""
|
||||
|
||||
stagger: float = Field(
|
||||
description="The stagger to use for the LLM service.",
|
||||
default=defs.PARALLELIZATION_STAGGER,
|
||||
)
|
||||
num_threads: int = Field(
|
||||
description="The number of threads to use for the LLM service.",
|
||||
default=defs.PARALLELIZATION_NUM_THREADS,
|
||||
)
|
||||
@ -5,13 +5,13 @@
|
||||
|
||||
from pathlib import Path
|
||||
|
||||
from pydantic import Field
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
import graphrag.config.defaults as defs
|
||||
from graphrag.config.models.llm_config import LLMConfig
|
||||
from graphrag.config.models.language_model_config import LanguageModelConfig
|
||||
|
||||
|
||||
class SummarizeDescriptionsConfig(LLMConfig):
|
||||
class SummarizeDescriptionsConfig(BaseModel):
|
||||
"""Configuration section for description summarization."""
|
||||
|
||||
prompt: str | None = Field(
|
||||
@ -24,8 +24,14 @@ class SummarizeDescriptionsConfig(LLMConfig):
|
||||
strategy: dict | None = Field(
|
||||
description="The override strategy to use.", default=None
|
||||
)
|
||||
model_id: str = Field(
|
||||
description="The model ID to use for summarization.",
|
||||
default=defs.SUMMARIZE_MODEL_ID,
|
||||
)
|
||||
|
||||
def resolved_strategy(self, root_dir: str) -> dict:
|
||||
def resolved_strategy(
|
||||
self, root_dir: str, model_config: LanguageModelConfig
|
||||
) -> dict:
|
||||
"""Get the resolved description summarization strategy."""
|
||||
from graphrag.index.operations.summarize_descriptions import (
|
||||
SummarizeStrategyType,
|
||||
@ -33,11 +39,12 @@ class SummarizeDescriptionsConfig(LLMConfig):
|
||||
|
||||
return self.strategy or {
|
||||
"type": SummarizeStrategyType.graph_intelligence,
|
||||
"llm": self.llm.model_dump(),
|
||||
**self.parallelization.model_dump(),
|
||||
"summarize_prompt": (Path(root_dir) / self.prompt)
|
||||
.read_bytes()
|
||||
.decode(encoding="utf-8")
|
||||
"llm": model_config.model_dump(),
|
||||
"stagger": model_config.parallelization_stagger,
|
||||
"num_threads": model_config.parallelization_num_threads,
|
||||
"summarize_prompt": (Path(root_dir) / self.prompt).read_text(
|
||||
encoding="utf-8"
|
||||
)
|
||||
if self.prompt
|
||||
else None,
|
||||
"max_summary_length": self.max_length,
|
||||
|
||||
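With the LLM and parallelization settings folded into LanguageModelConfig, resolved_strategy now takes the model configuration explicitly instead of reading it from an inherited llm field. A hedged sketch of the new call shape, assuming a validated GraphRagConfig named config:

# Illustrative only: resolving the summarization strategy under the new signature.
model_cfg = config.get_language_model_config(config.summarize_descriptions.model_id)
strategy = config.summarize_descriptions.resolved_strategy(config.root_dir, model_cfg)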
@ -3,14 +3,14 @@
|
||||
|
||||
"""Parameterization settings for the default configuration."""
|
||||
|
||||
from pydantic import Field
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
import graphrag.config.defaults as defs
|
||||
from graphrag.config.enums import TextEmbeddingTarget
|
||||
from graphrag.config.models.llm_config import LLMConfig
|
||||
from graphrag.config.models.language_model_config import LanguageModelConfig
|
||||
|
||||
|
||||
class TextEmbeddingConfig(LLMConfig):
|
||||
class TextEmbeddingConfig(BaseModel):
|
||||
"""Configuration section for text embeddings."""
|
||||
|
||||
batch_size: int = Field(
|
||||
@ -21,18 +21,21 @@ class TextEmbeddingConfig(LLMConfig):
|
||||
default=defs.EMBEDDING_BATCH_MAX_TOKENS,
|
||||
)
|
||||
target: TextEmbeddingTarget = Field(
|
||||
description="The target to use. 'all' or 'required'.",
|
||||
description="The target to use. 'all', 'required', 'selected', or 'none'.",
|
||||
default=defs.EMBEDDING_TARGET,
|
||||
)
|
||||
skip: list[str] = Field(description="The specific embeddings to skip.", default=[])
|
||||
vector_store: dict | None = Field(
|
||||
description="The vector storage configuration", default=defs.VECTOR_STORE_DICT
|
||||
names: list[str] = Field(
|
||||
description="The specific embeddings to perform.", default=[]
|
||||
)
|
||||
strategy: dict | None = Field(
|
||||
description="The override strategy to use.", default=None
|
||||
)
|
||||
model_id: str = Field(
|
||||
description="The model ID to use for text embeddings.",
|
||||
default=defs.EMBEDDING_MODEL_ID,
|
||||
)
|
||||
|
||||
def resolved_strategy(self) -> dict:
|
||||
def resolved_strategy(self, model_config: LanguageModelConfig) -> dict:
|
||||
"""Get the resolved text embedding strategy."""
|
||||
from graphrag.index.operations.embed_text import (
|
||||
TextEmbedStrategyType,
|
||||
@ -40,8 +43,9 @@ class TextEmbeddingConfig(LLMConfig):
|
||||
|
||||
return self.strategy or {
|
||||
"type": TextEmbedStrategyType.openai,
|
||||
"llm": self.llm.model_dump(),
|
||||
**self.parallelization.model_dump(),
|
||||
"llm": model_config.model_dump(),
|
||||
"stagger": model_config.parallelization_stagger,
|
||||
"num_threads": model_config.parallelization_num_threads,
|
||||
"batch_size": self.batch_size,
|
||||
"batch_max_tokens": self.batch_max_tokens,
|
||||
}
|
||||
|
||||
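The embedding strategy is resolved the same way: the step carries only a model_id, and callers look the model up on the top-level config. Illustrative only, assuming the same config object as above:

# Illustrative only: resolving the text embedding strategy from the step's model_id.
embed_model_cfg = config.get_language_model_config(config.embeddings.model_id)
embed_strategy = config.embeddings.resolved_strategy(embed_model_cfg)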
78 graphrag/config/models/vector_store_config.py Normal file
@@ -0,0 +1,78 @@
# Copyright (c) 2024 Microsoft Corporation.
# Licensed under the MIT License

"""Parameterization settings for the default configuration."""

from pydantic import BaseModel, Field, model_validator

import graphrag.config.defaults as defs
from graphrag.vector_stores.factory import VectorStoreType


class VectorStoreConfig(BaseModel):
"""The default configuration section for Vector Store."""

type: str = Field(
description="The vector store type to use.",
default=defs.VECTOR_STORE_TYPE,
)

db_uri: str | None = Field(description="The database URI to use.", default=None)

def _validate_db_uri(self) -> None:
"""Validate the database URI."""
if self.type == VectorStoreType.LanceDB.value and (
self.db_uri is None or self.db_uri.strip() == ""
):
self.db_uri = defs.VECTOR_STORE_DB_URI

if self.type != VectorStoreType.LanceDB.value and (
self.db_uri is not None and self.db_uri.strip() != ""
):
msg = "vector_store.db_uri is only used when vector_store.type == lancedb. Please rerun `graphrag init` and select the correct vector store type."
raise ValueError(msg)

url: str | None = Field(
description="The database URL when type == azure_ai_search.",
default=None,
)

def _validate_url(self) -> None:
"""Validate the database URL."""
if self.type == VectorStoreType.AzureAISearch and (
self.url is None or self.url.strip() == ""
):
msg = "vector_store.url is required when vector_store.type == azure_ai_search. Please rerun `graphrag init` and select the correct vector store type."
raise ValueError(msg)

if self.type != VectorStoreType.AzureAISearch and (
self.url is not None and self.url.strip() != ""
):
msg = "vector_store.url is only used when vector_store.type == azure_ai_search. Please rerun `graphrag init` and select the correct vector store type."
raise ValueError(msg)

api_key: str | None = Field(
description="The database API key when type == azure_ai_search.",
default=None,
)

audience: str | None = Field(
description="The database audience when type == azure_ai_search.",
default=None,
)

container_name: str = Field(
description="The database name to use.",
default=defs.VECTOR_STORE_CONTAINER_NAME,
)

overwrite: bool = Field(
description="Overwrite the existing data.", default=defs.VECTOR_STORE_OVERWRITE
)

@model_validator(mode="after")
def _validate_model(self):
"""Validate the model."""
self._validate_db_uri()
self._validate_url()
return self
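A quick sketch of the two validator paths above, illustrative only with placeholder values: a LanceDB store may omit db_uri (the default is filled in), while an Azure AI Search store must set url and must not set db_uri.

# Illustrative only: exercising both vector store validator paths.
from graphrag.config.models.vector_store_config import VectorStoreConfig

lancedb_store = VectorStoreConfig(type="lancedb")  # db_uri falls back to the default
ai_search_store = VectorStoreConfig(
    type="azure_ai_search",
    url="https://<search-resource>.search.windows.net",
    api_key="<search-key>",
)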
@ -1,214 +0,0 @@
|
||||
# Copyright (c) 2024 Microsoft Corporation.
|
||||
# Licensed under the MIT License
|
||||
|
||||
"""Resolve timestamp variables in a path."""
|
||||
|
||||
import re
|
||||
from pathlib import Path
|
||||
from string import Template
|
||||
|
||||
from graphrag.config.enums import ReportingType, StorageType
|
||||
from graphrag.config.models.graph_rag_config import GraphRagConfig
|
||||
from graphrag.vector_stores.factory import VectorStoreType
|
||||
|
||||
|
||||
def _resolve_timestamp_path_with_value(path: str | Path, timestamp_value: str) -> Path:
|
||||
"""Resolve the timestamp in the path with the given timestamp value.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
path : str | Path
|
||||
The path containing ${timestamp} variables to resolve.
|
||||
timestamp_value : str
|
||||
The timestamp value used to resolve the path.
|
||||
|
||||
Returns
|
||||
-------
|
||||
Path
|
||||
The path with ${timestamp} variables resolved to the provided timestamp value.
|
||||
"""
|
||||
template = Template(str(path))
|
||||
resolved_path = template.substitute(timestamp=timestamp_value)
|
||||
return Path(resolved_path)
|
||||
|
||||
|
||||
def _resolve_timestamp_path_with_dir(
|
||||
path: str | Path, pattern: re.Pattern[str]
|
||||
) -> Path:
|
||||
"""Resolve the timestamp in the path with the latest available timestamp directory value.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
path : str | Path
|
||||
The path containing ${timestamp} variables to resolve.
|
||||
pattern : re.Pattern[str]
|
||||
The pattern to use to match the timestamp directories.
|
||||
|
||||
Returns
|
||||
-------
|
||||
Path
|
||||
The path with ${timestamp} variables resolved to the latest available timestamp directory value.
|
||||
|
||||
Raises
|
||||
------
|
||||
ValueError
|
||||
If the parent directory expecting to contain timestamp directories does not exist or is not a directory.
|
||||
Or if no timestamp directories are found in the parent directory that match the pattern.
|
||||
"""
|
||||
path = Path(path)
|
||||
path_parts = path.parts
|
||||
parent_dir = Path(path_parts[0])
|
||||
found_timestamp_pattern = False
|
||||
for _, part in enumerate(path_parts[1:]):
|
||||
if part.lower() == "${timestamp}":
|
||||
found_timestamp_pattern = True
|
||||
break
|
||||
parent_dir = parent_dir / part
|
||||
|
||||
# Path not using timestamp layout.
|
||||
if not found_timestamp_pattern:
|
||||
return path
|
||||
|
||||
if not parent_dir.exists() or not parent_dir.is_dir():
|
||||
msg = f"Parent directory {parent_dir} does not exist or is not a directory."
|
||||
raise ValueError(msg)
|
||||
|
||||
timestamp_dirs = [
|
||||
d for d in parent_dir.iterdir() if d.is_dir() and pattern.match(d.name)
|
||||
]
|
||||
timestamp_dirs.sort(key=lambda d: d.name, reverse=True)
|
||||
if len(timestamp_dirs) == 0:
|
||||
msg = f"No timestamp directories found in {parent_dir} that match {pattern.pattern}."
|
||||
raise ValueError(msg)
|
||||
return _resolve_timestamp_path_with_value(path, timestamp_dirs[0].name)
|
||||
|
||||
|
||||
def _resolve_timestamp_path(
|
||||
path: str | Path,
|
||||
pattern_or_timestamp_value: re.Pattern[str] | str | None = None,
|
||||
) -> Path:
|
||||
r"""Timestamp path resolver.
|
||||
|
||||
Resolve the timestamp in the path with the given timestamp value or
|
||||
with the latest available timestamp directory matching the given pattern.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
path : str | Path
|
||||
The path containing ${timestamp} variables to resolve.
|
||||
pattern_or_timestamp_value : re.Pattern[str] | str, default=re.compile(r"^\d{8}-\d{6}$")
|
||||
The pattern to use to match the timestamp directories or the timestamp value to use.
|
||||
If a string is provided, the path will be resolved with the given string value.
|
||||
Otherwise, the path will be resolved with the latest available timestamp directory
|
||||
that matches the given pattern.
|
||||
|
||||
Returns
|
||||
-------
|
||||
Path
|
||||
The path with ${timestamp} variables resolved to the provided timestamp value or
|
||||
the latest available timestamp directory.
|
||||
|
||||
Raises
|
||||
------
|
||||
ValueError
|
||||
If the parent directory expecting to contain timestamp directories does not exist or is not a directory.
|
||||
Or if no timestamp directories are found in the parent directory that match the pattern.
|
||||
"""
|
||||
if not pattern_or_timestamp_value:
|
||||
pattern_or_timestamp_value = re.compile(r"^\d{8}-\d{6}$")
|
||||
if isinstance(pattern_or_timestamp_value, str):
|
||||
return _resolve_timestamp_path_with_value(path, pattern_or_timestamp_value)
|
||||
return _resolve_timestamp_path_with_dir(path, pattern_or_timestamp_value)
|
||||
|
||||
|
||||
def resolve_path(
|
||||
path_to_resolve: Path | str,
|
||||
root_dir: Path | str | None = None,
|
||||
pattern_or_timestamp_value: re.Pattern[str] | str | None = None,
|
||||
) -> Path:
|
||||
"""Resolve the path.
|
||||
|
||||
Resolves any timestamp variables by either using the provided timestamp value if string or
|
||||
by looking up the latest available timestamp directory that matches the given pattern.
|
||||
Resolves the path against the root directory if provided.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
path_to_resolve : Path | str
|
||||
The path to resolve.
|
||||
root_dir : Path | str | None default=None
|
||||
The root directory to resolve the path from, if provided.
|
||||
pattern_or_timestamp_value : re.Pattern[str] | str, default=None
|
||||
The pattern to use to match the timestamp directories or the timestamp value to use.
|
||||
If a string is provided, the path will be resolved with the given string value.
|
||||
Otherwise, the path will be resolved with the latest available timestamp directory
|
||||
that matches the given pattern.
|
||||
|
||||
Returns
|
||||
-------
|
||||
Path
|
||||
The resolved path.
|
||||
"""
|
||||
if root_dir:
|
||||
path_to_resolve = (Path(root_dir) / path_to_resolve).resolve()
|
||||
else:
|
||||
path_to_resolve = Path(path_to_resolve)
|
||||
return _resolve_timestamp_path(path_to_resolve, pattern_or_timestamp_value)
|
||||
|
||||
|
||||
def resolve_paths(
|
||||
config: GraphRagConfig,
|
||||
pattern_or_timestamp_value: re.Pattern[str] | str | None = None,
|
||||
) -> None:
|
||||
"""Resolve storage and reporting paths in the configuration for local file handling.
|
||||
|
||||
Resolves any timestamp variables in the configuration paths by either using the provided timestamp value if string or
|
||||
by looking up the latest available timestamp directory that matches the given pattern.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
config : GraphRagConfig
|
||||
The configuration to resolve the paths in.
|
||||
pattern_or_timestamp_value : re.Pattern[str] | str, default=None
|
||||
The pattern to use to match the timestamp directories or the timestamp value to use.
|
||||
If a string is provided, the path will be resolved with the given string value.
|
||||
Otherwise, the path will be resolved with the latest available timestamp directory
|
||||
that matches the given pattern.
|
||||
"""
|
||||
if config.storage.type == StorageType.file:
|
||||
config.storage.base_dir = str(
|
||||
resolve_path(
|
||||
config.storage.base_dir,
|
||||
config.root_dir,
|
||||
pattern_or_timestamp_value,
|
||||
)
|
||||
)
|
||||
|
||||
if (
|
||||
config.update_index_storage
|
||||
and config.update_index_storage.type == StorageType.file
|
||||
):
|
||||
config.update_index_storage.base_dir = str(
|
||||
resolve_path(
|
||||
config.update_index_storage.base_dir,
|
||||
config.root_dir,
|
||||
pattern_or_timestamp_value,
|
||||
)
|
||||
)
|
||||
|
||||
if config.reporting.type == ReportingType.file:
|
||||
config.reporting.base_dir = str(
|
||||
resolve_path(
|
||||
config.reporting.base_dir,
|
||||
config.root_dir,
|
||||
pattern_or_timestamp_value,
|
||||
)
|
||||
)
|
||||
|
||||
# TODO: must update filepath of lancedb (if used) until the new config engine has been implemented
|
||||
# TODO: remove the type ignore annotations below once the new config engine has been refactored
|
||||
vector_store_type = config.embeddings.vector_store["type"] # type: ignore
|
||||
if vector_store_type == VectorStoreType.LanceDB:
|
||||
db_uri = config.embeddings.vector_store["db_uri"] # type: ignore
|
||||
lancedb_dir = Path(config.root_dir).resolve() / db_uri
|
||||
config.embeddings.vector_store["db_uri"] = str(lancedb_dir) # type: ignore
|
||||
@ -1,4 +0,0 @@
|
||||
# Copyright (c) 2024 Microsoft Corporation.
|
||||
# Licensed under the MIT License
|
||||
|
||||
"""The Indexing Engine config typing package root."""
|
||||
@ -1,105 +0,0 @@
|
||||
# Copyright (c) 2024 Microsoft Corporation.
|
||||
# Licensed under the MIT License
|
||||
|
||||
"""A module containing 'PipelineCacheConfig', 'PipelineFileCacheConfig', 'PipelineMemoryCacheConfig', 'PipelineBlobCacheConfig', 'PipelineCosmosDBCacheConfig' models."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Generic, Literal, TypeVar
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from graphrag.config.enums import CacheType
|
||||
|
||||
T = TypeVar("T")
|
||||
|
||||
|
||||
class PipelineCacheConfig(BaseModel, Generic[T]):
|
||||
"""Represent the cache configuration for the pipeline."""
|
||||
|
||||
type: T
|
||||
|
||||
|
||||
class PipelineFileCacheConfig(PipelineCacheConfig[Literal[CacheType.file]]):
|
||||
"""Represent the file cache configuration for the pipeline."""
|
||||
|
||||
type: Literal[CacheType.file] = CacheType.file
|
||||
"""The type of cache."""
|
||||
|
||||
base_dir: str | None = Field(
|
||||
description="The base directory for the cache.", default=None
|
||||
)
|
||||
"""The base directory for the cache."""
|
||||
|
||||
|
||||
class PipelineMemoryCacheConfig(PipelineCacheConfig[Literal[CacheType.memory]]):
|
||||
"""Represent the memory cache configuration for the pipeline."""
|
||||
|
||||
type: Literal[CacheType.memory] = CacheType.memory
|
||||
"""The type of cache."""
|
||||
|
||||
|
||||
class PipelineNoneCacheConfig(PipelineCacheConfig[Literal[CacheType.none]]):
|
||||
"""Represent the none cache configuration for the pipeline."""
|
||||
|
||||
type: Literal[CacheType.none] = CacheType.none
|
||||
"""The type of cache."""
|
||||
|
||||
|
||||
class PipelineBlobCacheConfig(PipelineCacheConfig[Literal[CacheType.blob]]):
|
||||
"""Represents the blob cache configuration for the pipeline."""
|
||||
|
||||
type: Literal[CacheType.blob] = CacheType.blob
|
||||
"""The type of cache."""
|
||||
|
||||
base_dir: str | None = Field(
|
||||
description="The base directory for the cache.", default=None
|
||||
)
|
||||
"""The base directory for the cache."""
|
||||
|
||||
connection_string: str | None = Field(
|
||||
description="The blob cache connection string for the cache.", default=None
|
||||
)
|
||||
"""The blob cache connection string for the cache."""
|
||||
|
||||
container_name: str = Field(description="The container name for cache", default="")
|
||||
"""The container name for cache"""
|
||||
|
||||
storage_account_blob_url: str | None = Field(
|
||||
description="The storage account blob url for cache", default=None
|
||||
)
|
||||
"""The storage account blob url for cache"""
|
||||
|
||||
|
||||
class PipelineCosmosDBCacheConfig(PipelineCacheConfig[Literal[CacheType.cosmosdb]]):
|
||||
"""Represents the cosmosdb cache configuration for the pipeline."""
|
||||
|
||||
type: Literal[CacheType.cosmosdb] = CacheType.cosmosdb
|
||||
"""The type of cache."""
|
||||
|
||||
base_dir: str | None = Field(
|
||||
description="The cosmosdb database name for the cache.", default=None
|
||||
)
|
||||
"""The cosmosdb database name for the cache."""
|
||||
|
||||
container_name: str = Field(description="The container name for cache.", default="")
|
||||
"""The container name for cache."""
|
||||
|
||||
connection_string: str | None = Field(
|
||||
description="The cosmosdb primary key for the cache.", default=None
|
||||
)
|
||||
"""The cosmosdb primary key for the cache."""
|
||||
|
||||
cosmosdb_account_url: str | None = Field(
|
||||
description="The cosmosdb account url for cache", default=None
|
||||
)
|
||||
"""The cosmosdb account url for cache"""
|
||||
|
||||
|
||||
PipelineCacheConfigTypes = (
|
||||
PipelineFileCacheConfig
|
||||
| PipelineMemoryCacheConfig
|
||||
| PipelineBlobCacheConfig
|
||||
| PipelineNoneCacheConfig
|
||||
| PipelineCosmosDBCacheConfig
|
||||
)
|
||||
@ -1,110 +0,0 @@
|
||||
# Copyright (c) 2024 Microsoft Corporation.
|
||||
# Licensed under the MIT License
|
||||
|
||||
"""A module containing 'PipelineInputConfig', 'PipelineCSVInputConfig' and 'PipelineTextInputConfig' models."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Generic, Literal, TypeVar
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from graphrag.config.enums import InputFileType, InputType
|
||||
|
||||
T = TypeVar("T")
|
||||
|
||||
|
||||
class PipelineInputConfig(BaseModel, Generic[T]):
|
||||
"""Represent the configuration for an input."""
|
||||
|
||||
file_type: T
|
||||
"""The file type of input."""
|
||||
|
||||
type: InputType | None = Field(
|
||||
description="The input type to use.",
|
||||
default=None,
|
||||
)
|
||||
"""The input type to use."""
|
||||
|
||||
connection_string: str | None = Field(
|
||||
description="The blob cache connection string for the input files.",
|
||||
default=None,
|
||||
)
|
||||
"""The blob cache connection string for the input files."""
|
||||
|
||||
storage_account_blob_url: str | None = Field(
|
||||
description="The storage account blob url for the input files.", default=None
|
||||
)
|
||||
"""The storage account blob url for the input files."""
|
||||
|
||||
container_name: str | None = Field(
|
||||
description="The container name for input files.", default=None
|
||||
)
|
||||
"""The container name for the input files."""
|
||||
|
||||
base_dir: str | None = Field(
|
||||
description="The base directory for the input files.", default=None
|
||||
)
|
||||
"""The base directory for the input files."""
|
||||
|
||||
file_pattern: str = Field(description="The regex file pattern for the input files.")
|
||||
"""The regex file pattern for the input files."""
|
||||
|
||||
file_filter: dict[str, str] | None = Field(
|
||||
description="The optional file filter for the input files.", default=None
|
||||
)
|
||||
"""The optional file filter for the input files."""
|
||||
|
||||
encoding: str | None = Field(
|
||||
description="The encoding for the input files.", default=None
|
||||
)
|
||||
"""The encoding for the input files."""
|
||||
|
||||
|
||||
class PipelineCSVInputConfig(PipelineInputConfig[Literal[InputFileType.csv]]):
|
||||
"""Represent the configuration for a CSV input."""
|
||||
|
||||
file_type: Literal[InputFileType.csv] = InputFileType.csv
|
||||
|
||||
source_column: str | None = Field(
|
||||
description="The column to use as the source of the document.", default=None
|
||||
)
|
||||
"""The column to use as the source of the document."""
|
||||
|
||||
timestamp_column: str | None = Field(
|
||||
description="The column to use as the timestamp of the document.", default=None
|
||||
)
|
||||
"""The column to use as the timestamp of the document."""
|
||||
|
||||
timestamp_format: str | None = Field(
|
||||
description="The format of the timestamp column, so it can be parsed correctly.",
|
||||
default=None,
|
||||
)
|
||||
"""The format of the timestamp column, so it can be parsed correctly."""
|
||||
|
||||
text_column: str | None = Field(
|
||||
description="The column to use as the text of the document.", default=None
|
||||
)
|
||||
"""The column to use as the text of the document."""
|
||||
|
||||
title_column: str | None = Field(
|
||||
description="The column to use as the title of the document.", default=None
|
||||
)
|
||||
"""The column to use as the title of the document."""
|
||||
|
||||
|
||||
class PipelineTextInputConfig(PipelineInputConfig[Literal[InputFileType.text]]):
|
||||
"""Represent the configuration for a text input."""
|
||||
|
||||
file_type: Literal[InputFileType.text] = InputFileType.text
|
||||
|
||||
# Text Specific
|
||||
title_text_length: int | None = Field(
|
||||
description="Number of characters to use from the text as the title.",
|
||||
default=None,
|
||||
)
|
||||
"""Number of characters to use from the text as the title."""
|
||||
|
||||
|
||||
PipelineInputConfigTypes = PipelineCSVInputConfig | PipelineTextInputConfig
|
||||
"""Represent the types of inputs that can be used in a pipeline."""
|
||||
@ -1,66 +0,0 @@
|
||||
# Copyright (c) 2024 Microsoft Corporation.
|
||||
# Licensed under the MIT License
|
||||
|
||||
"""A module containing 'PipelineConfig' model."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from devtools import pformat
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from graphrag.index.config.cache import PipelineCacheConfigTypes
|
||||
from graphrag.index.config.input import PipelineInputConfigTypes
|
||||
from graphrag.index.config.reporting import PipelineReportingConfigTypes
|
||||
from graphrag.index.config.storage import PipelineStorageConfigTypes
|
||||
from graphrag.index.config.workflow import PipelineWorkflowReference
|
||||
|
||||
|
||||
class PipelineConfig(BaseModel):
|
||||
"""Represent the configuration for a pipeline."""
|
||||
|
||||
def __repr__(self) -> str:
|
||||
"""Get a string representation."""
|
||||
return pformat(self, highlight=False)
|
||||
|
||||
def __str__(self):
|
||||
"""Get a string representation."""
|
||||
return str(self.model_dump_json(indent=4))
|
||||
|
||||
extends: list[str] | str | None = Field(
|
||||
description="Extends another pipeline configuration", default=None
|
||||
)
|
||||
"""Extends another pipeline configuration"""
|
||||
|
||||
input: PipelineInputConfigTypes | None = Field(
|
||||
default=None, discriminator="file_type"
|
||||
)
|
||||
"""The input configuration for the pipeline."""
|
||||
|
||||
reporting: PipelineReportingConfigTypes | None = Field(
|
||||
default=None, discriminator="type"
|
||||
)
|
||||
"""The reporting configuration for the pipeline."""
|
||||
|
||||
storage: PipelineStorageConfigTypes | None = Field(
|
||||
default=None, discriminator="type"
|
||||
)
|
||||
"""The storage configuration for the pipeline."""
|
||||
|
||||
update_index_storage: PipelineStorageConfigTypes | None = Field(
|
||||
default=None, discriminator="type"
|
||||
)
|
||||
"""The storage configuration for the updated index."""
|
||||
|
||||
cache: PipelineCacheConfigTypes | None = Field(default=None, discriminator="type")
|
||||
"""The cache configuration for the pipeline."""
|
||||
|
||||
root_dir: str | None = Field(
|
||||
description="The root directory for the pipeline. All other paths will be based on this root_dir.",
|
||||
default=None,
|
||||
)
|
||||
"""The root directory for the pipeline."""
|
||||
|
||||
workflows: list[PipelineWorkflowReference] = Field(
|
||||
description="The workflows for the pipeline.", default_factory=list
|
||||
)
|
||||
"""The workflows for the pipeline."""
|
||||
@ -1,103 +0,0 @@
|
||||
# Copyright (c) 2024 Microsoft Corporation.
|
||||
# Licensed under the MIT License
|
||||
|
||||
"""A module containing 'PipelineStorageConfig', 'PipelineFileStorageConfig','PipelineMemoryStorageConfig', 'PipelineBlobStorageConfig', and 'PipelineCosmosDBStorageConfig' models."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Generic, Literal, TypeVar
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from graphrag.config.enums import StorageType
|
||||
|
||||
T = TypeVar("T")
|
||||
|
||||
|
||||
class PipelineStorageConfig(BaseModel, Generic[T]):
|
||||
"""Represent the storage configuration for the pipeline."""
|
||||
|
||||
type: T
|
||||
|
||||
|
||||
class PipelineFileStorageConfig(PipelineStorageConfig[Literal[StorageType.file]]):
|
||||
"""Represent the file storage configuration for the pipeline."""
|
||||
|
||||
type: Literal[StorageType.file] = StorageType.file
|
||||
"""The type of storage."""
|
||||
|
||||
base_dir: str | None = Field(
|
||||
description="The base directory for the storage.", default=None
|
||||
)
|
||||
"""The base directory for the storage."""
|
||||
|
||||
|
||||
class PipelineMemoryStorageConfig(PipelineStorageConfig[Literal[StorageType.memory]]):
|
||||
"""Represent the memory storage configuration for the pipeline."""
|
||||
|
||||
type: Literal[StorageType.memory] = StorageType.memory
|
||||
"""The type of storage."""
|
||||
|
||||
|
||||
class PipelineBlobStorageConfig(PipelineStorageConfig[Literal[StorageType.blob]]):
|
||||
"""Represents the blob storage configuration for the pipeline."""
|
||||
|
||||
type: Literal[StorageType.blob] = StorageType.blob
|
||||
"""The type of storage."""
|
||||
|
||||
connection_string: str | None = Field(
|
||||
description="The blob storage connection string for the storage.", default=None
|
||||
)
|
||||
"""The blob storage connection string for the storage."""
|
||||
|
||||
container_name: str = Field(
|
||||
description="The container name for storage", default=""
|
||||
)
|
||||
"""The container name for storage."""
|
||||
|
||||
base_dir: str | None = Field(
|
||||
description="The base directory for the storage.", default=None
|
||||
)
|
||||
"""The base directory for the storage."""
|
||||
|
||||
storage_account_blob_url: str | None = Field(
|
||||
description="The storage account blob url.", default=None
|
||||
)
|
||||
"""The storage account blob url."""
|
||||
|
||||
|
||||
class PipelineCosmosDBStorageConfig(
|
||||
PipelineStorageConfig[Literal[StorageType.cosmosdb]]
|
||||
):
|
||||
"""Represents the cosmosdb storage configuration for the pipeline."""
|
||||
|
||||
type: Literal[StorageType.cosmosdb] = StorageType.cosmosdb
|
||||
"""The type of storage."""
|
||||
|
||||
connection_string: str | None = Field(
|
||||
description="The cosmosdb storage primary key for the storage.", default=None
|
||||
)
|
||||
"""The cosmosdb storage primary key for the storage."""
|
||||
|
||||
container_name: str = Field(
|
||||
description="The container name for storage", default=""
|
||||
)
|
||||
"""The container name for storage."""
|
||||
|
||||
base_dir: str | None = Field(
|
||||
description="The base directory for the storage.", default=None
|
||||
)
|
||||
"""The base directory for the storage."""
|
||||
|
||||
cosmosdb_account_url: str | None = Field(
|
||||
description="The cosmosdb account url.", default=None
|
||||
)
|
||||
"""The cosmosdb account url."""
|
||||
|
||||
|
||||
PipelineStorageConfigTypes = (
|
||||
PipelineFileStorageConfig
|
||||
| PipelineMemoryStorageConfig
|
||||
| PipelineBlobStorageConfig
|
||||
| PipelineCosmosDBStorageConfig
|
||||
)
|
||||
@ -1,25 +0,0 @@
|
||||
# Copyright (c) 2024 Microsoft Corporation.
|
||||
# Licensed under the MIT License
|
||||
|
||||
"""A module containing 'PipelineWorkflowReference' model."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Any
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
PipelineWorkflowConfig = dict[str, Any]
|
||||
"""Represent a configuration for a workflow."""
|
||||
|
||||
|
||||
class PipelineWorkflowReference(BaseModel):
|
||||
"""Represent a reference to a workflow, and can optionally be the workflow itself."""
|
||||
|
||||
name: str | None = Field(description="Name of the workflow.", default=None)
|
||||
"""Name of the workflow."""
|
||||
|
||||
config: PipelineWorkflowConfig | None = Field(
|
||||
description="The optional configuration for the workflow.", default=None
|
||||
)
|
||||
"""The optional configuration for the workflow."""
|
||||
@ -1,456 +0,0 @@
|
||||
# Copyright (c) 2024 Microsoft Corporation.
|
||||
# Licensed under the MIT License
|
||||
|
||||
"""Default configuration methods definition."""
|
||||
|
||||
import json
|
||||
import logging
|
||||
from pathlib import Path
|
||||
|
||||
from graphrag.config.enums import (
|
||||
CacheType,
|
||||
InputFileType,
|
||||
ReportingType,
|
||||
StorageType,
|
||||
)
|
||||
from graphrag.config.models.graph_rag_config import GraphRagConfig
|
||||
from graphrag.config.models.storage_config import StorageConfig
|
||||
from graphrag.index.config.cache import (
|
||||
PipelineBlobCacheConfig,
|
||||
PipelineCacheConfigTypes,
|
||||
PipelineCosmosDBCacheConfig,
|
||||
PipelineFileCacheConfig,
|
||||
PipelineMemoryCacheConfig,
|
||||
PipelineNoneCacheConfig,
|
||||
)
|
||||
from graphrag.index.config.embeddings import get_embedded_fields, get_embedding_settings
|
||||
from graphrag.index.config.input import (
|
||||
PipelineCSVInputConfig,
|
||||
PipelineInputConfigTypes,
|
||||
PipelineTextInputConfig,
|
||||
)
|
||||
from graphrag.index.config.pipeline import (
|
||||
PipelineConfig,
|
||||
)
|
||||
from graphrag.index.config.reporting import (
|
||||
PipelineBlobReportingConfig,
|
||||
PipelineConsoleReportingConfig,
|
||||
PipelineFileReportingConfig,
|
||||
PipelineReportingConfigTypes,
|
||||
)
|
||||
from graphrag.index.config.storage import (
|
||||
PipelineBlobStorageConfig,
|
||||
PipelineCosmosDBStorageConfig,
|
||||
PipelineFileStorageConfig,
|
||||
PipelineMemoryStorageConfig,
|
||||
PipelineStorageConfigTypes,
|
||||
)
|
||||
from graphrag.index.config.workflow import (
|
||||
PipelineWorkflowReference,
|
||||
)
|
||||
from graphrag.index.workflows import (
|
||||
compute_communities,
|
||||
create_base_text_units,
|
||||
create_final_communities,
|
||||
create_final_community_reports,
|
||||
create_final_covariates,
|
||||
create_final_documents,
|
||||
create_final_entities,
|
||||
create_final_nodes,
|
||||
create_final_relationships,
|
||||
create_final_text_units,
|
||||
extract_graph,
|
||||
generate_text_embeddings,
|
||||
)
|
||||
|
||||
log = logging.getLogger(__name__)
|
||||
|
||||
builtin_document_attributes: set[str] = {
|
||||
"id",
|
||||
"source",
|
||||
"text",
|
||||
"title",
|
||||
"timestamp",
|
||||
"year",
|
||||
"month",
|
||||
"day",
|
||||
"hour",
|
||||
"minute",
|
||||
"second",
|
||||
}
|
||||
|
||||
|
||||
def create_pipeline_config(settings: GraphRagConfig, verbose=False) -> PipelineConfig:
|
||||
"""Get the default config for the pipeline."""
|
||||
# relative to the root_dir
|
||||
if verbose:
|
||||
_log_llm_settings(settings)
|
||||
|
||||
skip_workflows = settings.skip_workflows
|
||||
embedded_fields = get_embedded_fields(settings)
|
||||
covariates_enabled = (
|
||||
settings.claim_extraction.enabled
|
||||
and create_final_covariates not in skip_workflows
|
||||
)
|
||||
|
||||
result = PipelineConfig(
|
||||
root_dir=settings.root_dir,
|
||||
input=_get_pipeline_input_config(settings),
|
||||
reporting=_get_reporting_config(settings),
|
||||
storage=_get_storage_config(settings, settings.storage),
|
||||
update_index_storage=_get_storage_config(
|
||||
settings, settings.update_index_storage
|
||||
),
|
||||
cache=_get_cache_config(settings),
|
||||
workflows=[
|
||||
*_document_workflows(settings),
|
||||
*_text_unit_workflows(settings, covariates_enabled),
|
||||
*_graph_workflows(settings),
|
||||
*_community_workflows(settings, covariates_enabled),
|
||||
*(_covariate_workflows(settings) if covariates_enabled else []),
|
||||
*(_embeddings_workflows(settings, embedded_fields)),
|
||||
],
|
||||
)
|
||||
|
||||
# Remove any workflows that were specified to be skipped
|
||||
log.info("skipping workflows %s", ",".join(skip_workflows))
|
||||
result.workflows = [w for w in result.workflows if w.name not in skip_workflows]
|
||||
return result
|
||||
|
||||
|
||||
def _log_llm_settings(settings: GraphRagConfig) -> None:
|
||||
log.info(
|
||||
"Using LLM Config %s",
|
||||
json.dumps(
|
||||
{**settings.entity_extraction.llm.model_dump(), "api_key": "*****"},
|
||||
indent=4,
|
||||
),
|
||||
)
|
||||
log.info(
|
||||
"Using Embeddings Config %s",
|
||||
json.dumps(
|
||||
{**settings.embeddings.llm.model_dump(), "api_key": "*****"}, indent=4
|
||||
),
|
||||
)
|
||||
|
||||
|
||||
def _document_workflows(
|
||||
settings: GraphRagConfig,
|
||||
) -> list[PipelineWorkflowReference]:
|
||||
return [
|
||||
PipelineWorkflowReference(
|
||||
name=create_final_documents,
|
||||
config={
|
||||
"document_attribute_columns": list(
|
||||
{*(settings.input.document_attribute_columns)}
|
||||
- builtin_document_attributes
|
||||
),
|
||||
},
|
||||
),
|
||||
]
|
||||
|
||||
|
||||
def _text_unit_workflows(
|
||||
settings: GraphRagConfig,
|
||||
covariates_enabled: bool,
|
||||
) -> list[PipelineWorkflowReference]:
|
||||
return [
|
||||
PipelineWorkflowReference(
|
||||
name=create_base_text_units,
|
||||
config={
|
||||
"chunks": settings.chunks,
|
||||
"snapshot_transient": settings.snapshots.transient,
|
||||
},
|
||||
),
|
||||
PipelineWorkflowReference(
|
||||
name=create_final_text_units,
|
||||
config={
|
||||
"covariates_enabled": covariates_enabled,
|
||||
},
|
||||
),
|
||||
]
|
||||
|
||||
|
||||
def _graph_workflows(settings: GraphRagConfig) -> list[PipelineWorkflowReference]:
|
||||
return [
|
||||
PipelineWorkflowReference(
|
||||
name=extract_graph,
|
||||
config={
|
||||
"snapshot_graphml": settings.snapshots.graphml,
|
||||
"snapshot_transient": settings.snapshots.transient,
|
||||
"entity_extract": {
|
||||
**settings.entity_extraction.parallelization.model_dump(),
|
||||
"async_mode": settings.entity_extraction.async_mode,
|
||||
"strategy": settings.entity_extraction.resolved_strategy(
|
||||
settings.root_dir, settings.encoding_model
|
||||
),
|
||||
"entity_types": settings.entity_extraction.entity_types,
|
||||
},
|
||||
"summarize_descriptions": {
|
||||
**settings.summarize_descriptions.parallelization.model_dump(),
|
||||
"async_mode": settings.summarize_descriptions.async_mode,
|
||||
"strategy": settings.summarize_descriptions.resolved_strategy(
|
||||
settings.root_dir,
|
||||
),
|
||||
},
|
||||
},
|
||||
),
|
||||
PipelineWorkflowReference(
|
||||
name=compute_communities,
|
||||
config={
|
||||
"cluster_graph": settings.cluster_graph,
|
||||
"snapshot_transient": settings.snapshots.transient,
|
||||
},
|
||||
),
|
||||
PipelineWorkflowReference(
|
||||
name=create_final_entities,
|
||||
config={},
|
||||
),
|
||||
PipelineWorkflowReference(
|
||||
name=create_final_relationships,
|
||||
config={},
|
||||
),
|
||||
PipelineWorkflowReference(
|
||||
name=create_final_nodes,
|
||||
config={
|
||||
"layout_enabled": settings.umap.enabled,
|
||||
"embed_graph": settings.embed_graph,
|
||||
},
|
||||
),
|
||||
]
|
||||
|
||||
|
||||
def _community_workflows(
|
||||
settings: GraphRagConfig, covariates_enabled: bool
|
||||
) -> list[PipelineWorkflowReference]:
|
||||
return [
|
||||
PipelineWorkflowReference(name=create_final_communities),
|
||||
PipelineWorkflowReference(
|
||||
name=create_final_community_reports,
|
||||
config={
|
||||
"covariates_enabled": covariates_enabled,
|
||||
"create_community_reports": {
|
||||
**settings.community_reports.parallelization.model_dump(),
|
||||
"async_mode": settings.community_reports.async_mode,
|
||||
"strategy": settings.community_reports.resolved_strategy(
|
||||
settings.root_dir
|
||||
),
|
||||
},
|
||||
},
|
||||
),
|
||||
]
|
||||
|
||||
|
||||
def _covariate_workflows(
|
||||
settings: GraphRagConfig,
|
||||
) -> list[PipelineWorkflowReference]:
|
||||
return [
|
||||
PipelineWorkflowReference(
|
||||
name=create_final_covariates,
|
||||
config={
|
||||
"claim_extract": {
|
||||
**settings.claim_extraction.parallelization.model_dump(),
|
||||
"strategy": settings.claim_extraction.resolved_strategy(
|
||||
settings.root_dir, settings.encoding_model
|
||||
),
|
||||
},
|
||||
},
|
||||
)
|
||||
]
|
||||
|
||||
|
||||
def _embeddings_workflows(
|
||||
settings: GraphRagConfig, embedded_fields: set[str]
|
||||
) -> list[PipelineWorkflowReference]:
|
||||
return [
|
||||
PipelineWorkflowReference(
|
||||
name=generate_text_embeddings,
|
||||
config={
|
||||
"snapshot_embeddings": settings.snapshots.embeddings,
|
||||
"text_embed": get_embedding_settings(settings.embeddings),
|
||||
"embedded_fields": embedded_fields,
|
||||
},
|
||||
),
|
||||
]
|
||||
|
||||
|
||||
def _get_pipeline_input_config(
|
||||
settings: GraphRagConfig,
|
||||
) -> PipelineInputConfigTypes:
|
||||
file_type = settings.input.file_type
|
||||
match file_type:
|
||||
case InputFileType.csv:
|
||||
return PipelineCSVInputConfig(
|
||||
base_dir=settings.input.base_dir,
|
||||
file_pattern=settings.input.file_pattern,
|
||||
encoding=settings.input.encoding,
|
||||
source_column=settings.input.source_column,
|
||||
timestamp_column=settings.input.timestamp_column,
|
||||
timestamp_format=settings.input.timestamp_format,
|
||||
text_column=settings.input.text_column,
|
||||
title_column=settings.input.title_column,
|
||||
type=settings.input.type,
|
||||
connection_string=settings.input.connection_string,
|
||||
storage_account_blob_url=settings.input.storage_account_blob_url,
|
||||
container_name=settings.input.container_name,
|
||||
)
|
||||
case InputFileType.text:
|
||||
return PipelineTextInputConfig(
|
||||
base_dir=settings.input.base_dir,
|
||||
file_pattern=settings.input.file_pattern,
|
||||
encoding=settings.input.encoding,
|
||||
type=settings.input.type,
|
||||
connection_string=settings.input.connection_string,
|
||||
storage_account_blob_url=settings.input.storage_account_blob_url,
|
||||
container_name=settings.input.container_name,
|
||||
)
|
||||
case _:
|
||||
msg = f"Unknown input type: {file_type}"
|
||||
raise ValueError(msg)
|
||||
|
||||
|
||||
def _get_reporting_config(
|
||||
settings: GraphRagConfig,
|
||||
) -> PipelineReportingConfigTypes:
|
||||
"""Get the reporting config from the settings."""
|
||||
match settings.reporting.type:
|
||||
case ReportingType.file:
|
||||
# relative to the root_dir
|
||||
return PipelineFileReportingConfig(base_dir=settings.reporting.base_dir)
|
||||
case ReportingType.blob:
|
||||
connection_string = settings.reporting.connection_string
|
||||
storage_account_blob_url = settings.reporting.storage_account_blob_url
|
||||
container_name = settings.reporting.container_name
|
||||
if container_name is None:
|
||||
msg = "Container name must be provided for blob reporting."
|
||||
raise ValueError(msg)
|
||||
if connection_string is None and storage_account_blob_url is None:
|
||||
msg = "Connection string or storage account blob url must be provided for blob reporting."
|
||||
raise ValueError(msg)
|
||||
return PipelineBlobReportingConfig(
|
||||
connection_string=connection_string,
|
||||
container_name=container_name,
|
||||
base_dir=settings.reporting.base_dir,
|
||||
storage_account_blob_url=storage_account_blob_url,
|
||||
)
|
||||
case ReportingType.console:
|
||||
return PipelineConsoleReportingConfig()
|
||||
case _:
|
||||
# relative to the root_dir
|
||||
return PipelineFileReportingConfig(base_dir=settings.reporting.base_dir)
|
||||
|
||||
|
||||
def _get_storage_config(
|
||||
settings: GraphRagConfig,
|
||||
storage_settings: StorageConfig | None,
|
||||
) -> PipelineStorageConfigTypes | None:
|
||||
"""Get the storage type from the settings."""
|
||||
if not storage_settings:
|
||||
return None
|
||||
root_dir = settings.root_dir
|
||||
match storage_settings.type:
|
||||
case StorageType.memory:
|
||||
return PipelineMemoryStorageConfig()
|
||||
case StorageType.file:
|
||||
# relative to the root_dir
|
||||
base_dir = storage_settings.base_dir
|
||||
if base_dir is None:
|
||||
msg = "Base directory must be provided for file storage."
|
||||
raise ValueError(msg)
|
||||
return PipelineFileStorageConfig(base_dir=str(Path(root_dir) / base_dir))
|
||||
case StorageType.blob:
|
||||
connection_string = storage_settings.connection_string
|
||||
storage_account_blob_url = storage_settings.storage_account_blob_url
|
||||
container_name = storage_settings.container_name
|
||||
if container_name is None:
|
||||
msg = "Container name must be provided for blob storage."
|
||||
raise ValueError(msg)
|
||||
if connection_string is None and storage_account_blob_url is None:
|
||||
msg = "Connection string or storage account blob url must be provided for blob storage."
|
||||
raise ValueError(msg)
|
||||
return PipelineBlobStorageConfig(
|
||||
connection_string=connection_string,
|
||||
container_name=container_name,
|
||||
base_dir=storage_settings.base_dir,
|
||||
storage_account_blob_url=storage_account_blob_url,
|
||||
)
|
||||
case StorageType.cosmosdb:
|
||||
cosmosdb_account_url = storage_settings.cosmosdb_account_url
|
||||
connection_string = storage_settings.connection_string
|
||||
base_dir = storage_settings.base_dir
|
||||
container_name = storage_settings.container_name
|
||||
if connection_string is None and cosmosdb_account_url is None:
|
||||
msg = "Connection string or cosmosDB account url must be provided for cosmosdb storage."
|
||||
raise ValueError(msg)
|
||||
if base_dir is None:
|
||||
msg = "Base directory must be provided for cosmosdb storage."
|
||||
raise ValueError(msg)
|
||||
if container_name is None:
|
||||
msg = "Container name must be provided for cosmosdb storage."
|
||||
raise ValueError(msg)
|
||||
return PipelineCosmosDBStorageConfig(
|
||||
cosmosdb_account_url=cosmosdb_account_url,
|
||||
connection_string=connection_string,
|
||||
base_dir=storage_settings.base_dir,
|
||||
container_name=container_name,
|
||||
)
|
||||
case _:
|
||||
# relative to the root_dir
|
||||
base_dir = storage_settings.base_dir
|
||||
if base_dir is None:
|
||||
msg = "Base directory must be provided for file storage."
|
||||
raise ValueError(msg)
|
||||
return PipelineFileStorageConfig(base_dir=str(Path(root_dir) / base_dir))
|
||||
|
||||
|
||||
def _get_cache_config(
|
||||
settings: GraphRagConfig,
|
||||
) -> PipelineCacheConfigTypes:
|
||||
"""Get the cache type from the settings."""
|
||||
match settings.cache.type:
|
||||
case CacheType.memory:
|
||||
return PipelineMemoryCacheConfig()
|
||||
case CacheType.file:
|
||||
# relative to root dir
|
||||
return PipelineFileCacheConfig(base_dir=settings.cache.base_dir)
|
||||
case CacheType.none:
|
||||
return PipelineNoneCacheConfig()
|
||||
case CacheType.blob:
|
||||
connection_string = settings.cache.connection_string
|
||||
storage_account_blob_url = settings.cache.storage_account_blob_url
|
||||
container_name = settings.cache.container_name
|
||||
if container_name is None:
|
||||
msg = "Container name must be provided for blob cache."
|
||||
raise ValueError(msg)
|
||||
if connection_string is None and storage_account_blob_url is None:
|
||||
msg = "Connection string or storage account blob url must be provided for blob cache."
|
||||
raise ValueError(msg)
|
||||
return PipelineBlobCacheConfig(
|
||||
connection_string=connection_string,
|
||||
container_name=container_name,
|
||||
base_dir=settings.cache.base_dir,
|
||||
storage_account_blob_url=storage_account_blob_url,
|
||||
)
|
||||
case CacheType.cosmosdb:
|
||||
cosmosdb_account_url = settings.cache.cosmosdb_account_url
|
||||
connection_string = settings.cache.connection_string
|
||||
base_dir = settings.cache.base_dir
|
||||
container_name = settings.cache.container_name
|
||||
if base_dir is None:
|
||||
msg = "Base directory must be provided for cosmosdb cache."
|
||||
raise ValueError(msg)
|
||||
if container_name is None:
|
||||
msg = "Container name must be provided for cosmosdb cache."
|
||||
raise ValueError(msg)
|
||||
if connection_string is None and cosmosdb_account_url is None:
|
||||
msg = "Connection string or cosmosDB account url must be provided for cosmosdb cache."
|
||||
raise ValueError(msg)
|
||||
return PipelineCosmosDBCacheConfig(
|
||||
cosmosdb_account_url=cosmosdb_account_url,
|
||||
connection_string=connection_string,
|
||||
base_dir=base_dir,
|
||||
container_name=container_name,
|
||||
)
|
||||
case _:
|
||||
# relative to root dir
|
||||
return PipelineFileCacheConfig(base_dir="./cache")
|
||||
@ -1,25 +0,0 @@
# Copyright (c) 2024 Microsoft Corporation.
# Licensed under the MIT License

"""GraphRAG indexing error types."""


class NoWorkflowsDefinedError(ValueError):
    """Exception for no workflows defined."""

    def __init__(self):
        super().__init__("No workflows defined.")


class UndefinedWorkflowError(ValueError):
    """Exception for invalid verb input."""

    def __init__(self):
        super().__init__("Workflow name is undefined.")


class UnknownWorkflowError(ValueError):
    """Exception for invalid verb input."""

    def __init__(self, name: str):
        super().__init__(f"Unknown workflow: {name}")
@ -70,12 +70,12 @@ async def create_final_community_reports(
    )

    community_reports = await summarize_communities(
        local_contexts,
        nodes,
        community_hierarchy,
        callbacks,
        cache,
        summarization_strategy,
        local_contexts=local_contexts,
        nodes=nodes,
        community_hierarchy=community_hierarchy,
        callbacks=callbacks,
        cache=cache,
        strategy=summarization_strategy,
        async_mode=async_mode,
        num_threads=num_threads,
    )

@ -31,15 +31,15 @@ async def create_final_covariates(
|
||||
# this also results in text_unit_id being copied to the output covariate table
|
||||
text_units["text_unit_id"] = text_units["id"]
|
||||
covariates = await extract_covariates(
|
||||
text_units,
|
||||
callbacks,
|
||||
cache,
|
||||
"text",
|
||||
covariate_type,
|
||||
extraction_strategy,
|
||||
async_mode,
|
||||
entity_types,
|
||||
num_threads,
|
||||
input=text_units,
|
||||
callbacks=callbacks,
|
||||
cache=cache,
|
||||
column="text",
|
||||
covariate_type=covariate_type,
|
||||
strategy=extraction_strategy,
|
||||
async_mode=async_mode,
|
||||
entity_types=entity_types,
|
||||
num_threads=num_threads,
|
||||
)
|
||||
text_units.drop(columns=["text_unit_id"], inplace=True) # don't pollute the global
|
||||
covariates["id"] = covariates["covariate_type"].apply(lambda _x: str(uuid4()))
|
||||
|
||||
@ -31,9 +31,9 @@ async def extract_graph(
|
||||
"""All the steps to create the base entity graph."""
|
||||
# this returns a graph for each text unit, to be merged later
|
||||
entities, relationships = await extract_entities(
|
||||
text_units,
|
||||
callbacks,
|
||||
cache,
|
||||
text_units=text_units,
|
||||
callbacks=callbacks,
|
||||
cache=cache,
|
||||
text_column="text",
|
||||
id_column="id",
|
||||
strategy=extraction_strategy,
|
||||
@ -55,10 +55,10 @@ async def extract_graph(
|
||||
raise ValueError(error_msg)
|
||||
|
||||
entity_summaries, relationship_summaries = await summarize_descriptions(
|
||||
entities,
|
||||
relationships,
|
||||
callbacks,
|
||||
cache,
|
||||
entities_df=entities,
|
||||
relationships_df=relationships,
|
||||
callbacks=callbacks,
|
||||
cache=cache,
|
||||
strategy=summarization_strategy,
|
||||
num_threads=summarization_num_threads,
|
||||
)
|
||||
|
||||
@ -9,7 +9,7 @@ import pandas as pd
|
||||
|
||||
from graphrag.cache.pipeline_cache import PipelineCache
|
||||
from graphrag.callbacks.workflow_callbacks import WorkflowCallbacks
|
||||
from graphrag.index.config.embeddings import (
|
||||
from graphrag.config.embeddings import (
|
||||
community_full_content_embedding,
|
||||
community_summary_embedding,
|
||||
community_title_embedding,
|
||||
@ -119,9 +119,9 @@ async def _run_and_snapshot_embeddings(
|
||||
"""All the steps to generate single embedding."""
|
||||
if text_embed_config:
|
||||
data["embedding"] = await embed_text(
|
||||
data,
|
||||
callbacks,
|
||||
cache,
|
||||
input=data,
|
||||
callbacks=callbacks,
|
||||
cache=cache,
|
||||
embed_column=embed_column,
|
||||
embedding_name=name,
|
||||
strategy=text_embed_config["strategy"],
|
||||
|
||||
@ -6,11 +6,10 @@
|
||||
import logging
|
||||
import re
|
||||
from io import BytesIO
|
||||
from typing import cast
|
||||
|
||||
import pandas as pd
|
||||
|
||||
from graphrag.index.config.input import PipelineCSVInputConfig, PipelineInputConfig
|
||||
from graphrag.config.models.input_config import InputConfig
|
||||
from graphrag.index.utils.hashing import gen_sha512_hash
|
||||
from graphrag.logger.base import ProgressLogger
|
||||
from graphrag.storage.pipeline_storage import PipelineStorage
|
||||
@ -23,19 +22,18 @@ input_type = "csv"
|
||||
|
||||
|
||||
async def load(
|
||||
config: PipelineInputConfig,
|
||||
config: InputConfig,
|
||||
progress: ProgressLogger | None,
|
||||
storage: PipelineStorage,
|
||||
) -> pd.DataFrame:
|
||||
"""Load csv inputs from a directory."""
|
||||
csv_config = cast("PipelineCSVInputConfig", config)
|
||||
log.info("Loading csv files from %s", csv_config.base_dir)
|
||||
log.info("Loading csv files from %s", config.base_dir)
|
||||
|
||||
async def load_file(path: str, group: dict | None) -> pd.DataFrame:
|
||||
if group is None:
|
||||
group = {}
|
||||
buffer = BytesIO(await storage.get(path, as_bytes=True))
|
||||
data = pd.read_csv(buffer, encoding=config.encoding or "latin-1")
|
||||
data = pd.read_csv(buffer, encoding=config.encoding)
|
||||
additional_keys = group.keys()
|
||||
if len(additional_keys) > 0:
|
||||
data[[*additional_keys]] = data.apply(
|
||||
@ -43,51 +41,49 @@ async def load(
|
||||
)
|
||||
if "id" not in data.columns:
|
||||
data["id"] = data.apply(lambda x: gen_sha512_hash(x, x.keys()), axis=1)
|
||||
if csv_config.source_column is not None and "source" not in data.columns:
|
||||
if csv_config.source_column not in data.columns:
|
||||
if config.source_column is not None and "source" not in data.columns:
|
||||
if config.source_column not in data.columns:
|
||||
log.warning(
|
||||
"source_column %s not found in csv file %s",
|
||||
csv_config.source_column,
|
||||
config.source_column,
|
||||
path,
|
||||
)
|
||||
else:
|
||||
data["source"] = data.apply(
|
||||
lambda x: x[csv_config.source_column], axis=1
|
||||
)
|
||||
if csv_config.text_column is not None and "text" not in data.columns:
|
||||
if csv_config.text_column not in data.columns:
|
||||
data["source"] = data.apply(lambda x: x[config.source_column], axis=1)
|
||||
if config.text_column is not None and "text" not in data.columns:
|
||||
if config.text_column not in data.columns:
|
||||
log.warning(
|
||||
"text_column %s not found in csv file %s",
|
||||
csv_config.text_column,
|
||||
config.text_column,
|
||||
path,
|
||||
)
|
||||
else:
|
||||
data["text"] = data.apply(lambda x: x[csv_config.text_column], axis=1)
|
||||
if csv_config.title_column is not None and "title" not in data.columns:
|
||||
if csv_config.title_column not in data.columns:
|
||||
data["text"] = data.apply(lambda x: x[config.text_column], axis=1)
|
||||
if config.title_column is not None and "title" not in data.columns:
|
||||
if config.title_column not in data.columns:
|
||||
log.warning(
|
||||
"title_column %s not found in csv file %s",
|
||||
csv_config.title_column,
|
||||
config.title_column,
|
||||
path,
|
||||
)
|
||||
else:
|
||||
data["title"] = data.apply(lambda x: x[csv_config.title_column], axis=1)
|
||||
data["title"] = data.apply(lambda x: x[config.title_column], axis=1)
|
||||
|
||||
if csv_config.timestamp_column is not None:
|
||||
fmt = csv_config.timestamp_format
|
||||
if config.timestamp_column is not None:
|
||||
fmt = config.timestamp_format
|
||||
if fmt is None:
|
||||
msg = "Must specify timestamp_format if timestamp_column is specified"
|
||||
raise ValueError(msg)
|
||||
|
||||
if csv_config.timestamp_column not in data.columns:
|
||||
if config.timestamp_column not in data.columns:
|
||||
log.warning(
|
||||
"timestamp_column %s not found in csv file %s",
|
||||
csv_config.timestamp_column,
|
||||
config.timestamp_column,
|
||||
path,
|
||||
)
|
||||
else:
|
||||
data["timestamp"] = pd.to_datetime(
|
||||
data[csv_config.timestamp_column], format=fmt
|
||||
data[config.timestamp_column], format=fmt
|
||||
)
|
||||
|
||||
# TODO: Theres probably a less gross way to do this
|
||||
|
||||
@ -12,7 +12,6 @@ import pandas as pd
|
||||
|
||||
from graphrag.config.enums import InputType
|
||||
from graphrag.config.models.input_config import InputConfig
|
||||
from graphrag.index.config.input import PipelineInputConfig
|
||||
from graphrag.index.input.csv import input_type as csv
|
||||
from graphrag.index.input.csv import load as load_csv
|
||||
from graphrag.index.input.text import input_type as text
|
||||
@ -30,7 +29,7 @@ loaders: dict[str, Callable[..., Awaitable[pd.DataFrame]]] = {
|
||||
|
||||
|
||||
async def create_input(
|
||||
config: PipelineInputConfig | InputConfig,
|
||||
config: InputConfig,
|
||||
progress_reporter: ProgressLogger | None = None,
|
||||
root_dir: str | None = None,
|
||||
) -> pd.DataFrame:
|
||||
|
||||
@ -10,7 +10,7 @@ from typing import Any
|
||||
|
||||
import pandas as pd
|
||||
|
||||
from graphrag.index.config.input import PipelineInputConfig
|
||||
from graphrag.config.models.input_config import InputConfig
|
||||
from graphrag.index.utils.hashing import gen_sha512_hash
|
||||
from graphrag.logger.base import ProgressLogger
|
||||
from graphrag.storage.pipeline_storage import PipelineStorage
|
||||
@ -23,7 +23,7 @@ log = logging.getLogger(__name__)
|
||||
|
||||
|
||||
async def load(
|
||||
config: PipelineInputConfig,
|
||||
config: InputConfig,
|
||||
progress: ProgressLogger | None,
|
||||
storage: PipelineStorage,
|
||||
) -> pd.DataFrame:
|
||||
|
||||
@ -19,11 +19,12 @@ from fnllm.openai import (
    create_openai_embeddings_llm,
)
from fnllm.openai.types.chat.parameters import OpenAIChatParameters
from pydantic import TypeAdapter

import graphrag.config.defaults as defs
from graphrag.config.enums import LLMType
from graphrag.config.models.llm_parameters import LLMParameters
from graphrag.config.models.language_model_config import (
    LanguageModelConfig,  # noqa: TC001
)
from graphrag.index.llm.manager import ChatLLMSingleton, EmbeddingsLLMSingleton

from .mock_llm import MockChatLLM
@ -93,17 +94,9 @@ def create_cache(cache: PipelineCache | None, name: str) -> LLMCache | None:
    return GraphRagLLMCache(cache).child(name)


def read_llm_params(llm_args: dict[str, Any]) -> LLMParameters:
    """Read the LLM parameters from the arguments."""
    if llm_args == {}:
        msg = "LLM arguments are required"
        raise ValueError(msg)
    return TypeAdapter(LLMParameters).validate_python(llm_args)


def load_llm(
    name: str,
    config: LLMParameters,
    config: LanguageModelConfig,
    *,
    callbacks: WorkflowCallbacks,
    cache: PipelineCache | None,
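Downstream strategies now hand their "llm" dict straight to the pydantic model instead of calling read_llm_params. A minimal sketch of the new loading path follows; the config values are placeholders, and the calls mirror the LanguageModelConfig(**...) and load_llm(...) usage that appears later in this diff.

# Sketch only: placeholder model/credential values, loading an LLM outside the pipeline.
from graphrag.callbacks.noop_workflow_callbacks import NoopWorkflowCallbacks
from graphrag.config.models.language_model_config import LanguageModelConfig
from graphrag.index.llm.load_llm import load_llm

strategy_config = {
    "llm": {
        "type": "openai_chat",       # an LLMType value
        "model": "gpt-4o-mini",      # placeholder model name
        "api_key": "sk-placeholder",  # normally injected from settings.yaml or the environment
    }
}

llm_config = LanguageModelConfig(**strategy_config["llm"])
llm = load_llm(
    "entity_extraction",
    llm_config,
    callbacks=NoopWorkflowCallbacks(),
    cache=None,
)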
@ -133,7 +126,7 @@ def load_llm(
|
||||
|
||||
def load_llm_embeddings(
|
||||
name: str,
|
||||
llm_config: LLMParameters,
|
||||
llm_config: LanguageModelConfig,
|
||||
*,
|
||||
callbacks: WorkflowCallbacks,
|
||||
cache: PipelineCache | None,
|
||||
@ -174,7 +167,7 @@ def _create_error_handler(callbacks: WorkflowCallbacks) -> ErrorHandlerFn:
|
||||
def _load_openai_chat_llm(
|
||||
on_error: ErrorHandlerFn,
|
||||
cache: LLMCache,
|
||||
config: LLMParameters,
|
||||
config: LanguageModelConfig,
|
||||
azure=False,
|
||||
):
|
||||
return _create_openai_chat_llm(
|
||||
@ -187,7 +180,7 @@ def _load_openai_chat_llm(
|
||||
def _load_openai_embeddings_llm(
|
||||
on_error: ErrorHandlerFn,
|
||||
cache: LLMCache,
|
||||
config: LLMParameters,
|
||||
config: LanguageModelConfig,
|
||||
azure=False,
|
||||
):
|
||||
return _create_openai_embeddings_llm(
|
||||
@ -197,8 +190,8 @@ def _load_openai_embeddings_llm(
|
||||
)
|
||||
|
||||
|
||||
def _create_openai_config(config: LLMParameters, azure: bool) -> OpenAIConfig:
|
||||
encoding_model = config.encoding_model or defs.ENCODING_MODEL
|
||||
def _create_openai_config(config: LanguageModelConfig, azure: bool) -> OpenAIConfig:
|
||||
encoding_model = config.encoding_model
|
||||
json_strategy = (
|
||||
JsonStrategy.VALID if config.model_supports_json else JsonStrategy.LOOSE
|
||||
)
|
||||
@ -253,19 +246,19 @@ def _create_openai_config(config: LLMParameters, azure: bool) -> OpenAIConfig:
|
||||
|
||||
|
||||
def _load_azure_openai_chat_llm(
|
||||
on_error: ErrorHandlerFn, cache: LLMCache, config: LLMParameters
|
||||
on_error: ErrorHandlerFn, cache: LLMCache, config: LanguageModelConfig
|
||||
):
|
||||
return _load_openai_chat_llm(on_error, cache, config, True)
|
||||
|
||||
|
||||
def _load_azure_openai_embeddings_llm(
|
||||
on_error: ErrorHandlerFn, cache: LLMCache, config: LLMParameters
|
||||
on_error: ErrorHandlerFn, cache: LLMCache, config: LanguageModelConfig
|
||||
):
|
||||
return _load_openai_embeddings_llm(on_error, cache, config, True)
|
||||
|
||||
|
||||
def _load_static_response(
|
||||
_on_error: ErrorHandlerFn, _cache: PipelineCache, config: LLMParameters
|
||||
_on_error: ErrorHandlerFn, _cache: PipelineCache, config: LanguageModelConfig
|
||||
) -> ChatLLM:
|
||||
if config.responses is None:
|
||||
msg = "Static response LLM requires responses"
|
||||
|
||||
@ -12,8 +12,8 @@ import pandas as pd
|
||||
|
||||
from graphrag.cache.pipeline_cache import PipelineCache
|
||||
from graphrag.callbacks.workflow_callbacks import WorkflowCallbacks
|
||||
from graphrag.config.embeddings import create_collection_name
|
||||
from graphrag.index.operations.embed_text.strategies.typing import TextEmbeddingStrategy
|
||||
from graphrag.utils.embeddings import create_collection_name
|
||||
from graphrag.vector_stores.base import BaseVectorStore, VectorStoreDocument
|
||||
from graphrag.vector_stores.factory import VectorStoreFactory
|
||||
|
||||
@ -87,23 +87,23 @@ async def embed_text(
|
||||
embedding_name, vector_store_config
|
||||
)
|
||||
return await _text_embed_with_vector_store(
|
||||
input,
|
||||
callbacks,
|
||||
cache,
|
||||
embed_column,
|
||||
strategy,
|
||||
vector_store,
|
||||
vector_store_workflow_config,
|
||||
input=input,
|
||||
callbacks=callbacks,
|
||||
cache=cache,
|
||||
embed_column=embed_column,
|
||||
strategy=strategy,
|
||||
vector_store=vector_store,
|
||||
vector_store_config=vector_store_workflow_config,
|
||||
id_column=id_column,
|
||||
title_column=title_column,
|
||||
)
|
||||
|
||||
return await _text_embed_in_memory(
|
||||
input,
|
||||
callbacks,
|
||||
cache,
|
||||
embed_column,
|
||||
strategy,
|
||||
input=input,
|
||||
callbacks=callbacks,
|
||||
cache=cache,
|
||||
embed_column=embed_column,
|
||||
strategy=strategy,
|
||||
)
|
||||
|
||||
|
||||
@ -176,12 +176,7 @@ async def _text_embed_with_vector_store(
|
||||
texts: list[str] = batch[embed_column].to_numpy().tolist()
|
||||
titles: list[str] = batch[title].to_numpy().tolist()
|
||||
ids: list[str] = batch[id_column].to_numpy().tolist()
|
||||
result = await strategy_exec(
|
||||
texts,
|
||||
callbacks,
|
||||
cache,
|
||||
strategy_args,
|
||||
)
|
||||
result = await strategy_exec(texts, callbacks, cache, strategy_args)
|
||||
if result.embeddings:
|
||||
embeddings = [
|
||||
embedding for embedding in result.embeddings if embedding is not None
|
||||
|
||||
@ -9,12 +9,10 @@ from typing import Any
|
||||
|
||||
import numpy as np
|
||||
from fnllm import EmbeddingsLLM
|
||||
from pydantic import TypeAdapter
|
||||
|
||||
import graphrag.config.defaults as defs
|
||||
from graphrag.cache.pipeline_cache import PipelineCache
|
||||
from graphrag.callbacks.workflow_callbacks import WorkflowCallbacks
|
||||
from graphrag.config.models.llm_parameters import LLMParameters
|
||||
from graphrag.config.models.language_model_config import LanguageModelConfig
|
||||
from graphrag.index.llm.load_llm import load_llm_embeddings
|
||||
from graphrag.index.operations.embed_text.strategies.typing import TextEmbeddingResult
|
||||
from graphrag.index.text_splitting.text_splitting import TokenTextSplitter
|
||||
@ -34,9 +32,10 @@ async def run(
|
||||
if is_null(input):
|
||||
return TextEmbeddingResult(embeddings=None)
|
||||
|
||||
llm_config = TypeAdapter(LLMParameters).validate_python(args.get("llm", {}))
|
||||
batch_size = args.get("batch_size", 16)
|
||||
batch_max_tokens = args.get("batch_max_tokens", 8191)
|
||||
llm_config = args["llm"]
|
||||
llm_config = LanguageModelConfig(**args["llm"])
|
||||
splitter = _get_splitter(llm_config, batch_max_tokens)
|
||||
llm = _get_llm(llm_config, callbacks, cache)
|
||||
semaphore: asyncio.Semaphore = asyncio.Semaphore(args.get("num_threads", 4))
|
||||
@ -66,15 +65,17 @@ async def run(
|
||||
return TextEmbeddingResult(embeddings=embeddings)
|
||||
|
||||
|
||||
def _get_splitter(config: LLMParameters, batch_max_tokens: int) -> TokenTextSplitter:
|
||||
def _get_splitter(
|
||||
config: LanguageModelConfig, batch_max_tokens: int
|
||||
) -> TokenTextSplitter:
|
||||
return TokenTextSplitter(
|
||||
encoding_name=config.encoding_model or defs.ENCODING_MODEL,
|
||||
encoding_name=config.encoding_model,
|
||||
chunk_size=batch_max_tokens,
|
||||
)
|
||||
|
||||
|
||||
def _get_llm(
|
||||
config: LLMParameters,
|
||||
config: LanguageModelConfig,
|
||||
callbacks: WorkflowCallbacks,
|
||||
cache: PipelineCache,
|
||||
) -> EmbeddingsLLM:
|
||||
|
||||
@ -14,7 +14,8 @@ import graphrag.config.defaults as defs
|
||||
from graphrag.cache.pipeline_cache import PipelineCache
|
||||
from graphrag.callbacks.workflow_callbacks import WorkflowCallbacks
|
||||
from graphrag.config.enums import AsyncType
|
||||
from graphrag.index.llm.load_llm import load_llm, read_llm_params
|
||||
from graphrag.config.models.language_model_config import LanguageModelConfig
|
||||
from graphrag.index.llm.load_llm import load_llm
|
||||
from graphrag.index.operations.extract_covariates.claim_extractor import ClaimExtractor
|
||||
from graphrag.index.operations.extract_covariates.typing import (
|
||||
Covariate,
|
||||
@ -52,7 +53,12 @@ async def extract_covariates(
|
||||
async def run_strategy(row):
|
||||
text = row[column]
|
||||
result = await run_claim_extraction(
|
||||
text, entity_types, resolved_entities_map, callbacks, cache, strategy_config
|
||||
input=text,
|
||||
entity_types=entity_types,
|
||||
resolved_entities_map=resolved_entities_map,
|
||||
callbacks=callbacks,
|
||||
cache=cache,
|
||||
strategy_config=strategy_config,
|
||||
)
|
||||
return [
|
||||
create_row_from_claim_data(row, item, covariate_type)
|
||||
@ -83,8 +89,13 @@ async def run_claim_extraction(
|
||||
strategy_config: dict[str, Any],
|
||||
) -> CovariateExtractionResult:
|
||||
"""Run the Claim extraction chain."""
|
||||
llm_config = read_llm_params(strategy_config.get("llm", {}))
|
||||
llm = load_llm("claim_extraction", llm_config, callbacks=callbacks, cache=cache)
|
||||
llm_config = LanguageModelConfig(**strategy_config["llm"])
|
||||
llm = load_llm(
|
||||
"claim_extraction",
|
||||
llm_config,
|
||||
callbacks=callbacks,
|
||||
cache=cache,
|
||||
)
|
||||
extraction_prompt = strategy_config.get("extraction_prompt")
|
||||
max_gleanings = strategy_config.get("max_gleanings", defs.CLAIM_MAX_GLEANINGS)
|
||||
tuple_delimiter = strategy_config.get("tuple_delimiter")
|
||||
|
||||
@ -9,7 +9,8 @@ from fnllm import ChatLLM
|
||||
import graphrag.config.defaults as defs
|
||||
from graphrag.cache.pipeline_cache import PipelineCache
|
||||
from graphrag.callbacks.workflow_callbacks import WorkflowCallbacks
|
||||
from graphrag.index.llm.load_llm import load_llm, read_llm_params
|
||||
from graphrag.config.models.language_model_config import LanguageModelConfig
|
||||
from graphrag.index.llm.load_llm import load_llm
|
||||
from graphrag.index.operations.extract_entities.graph_extractor import GraphExtractor
|
||||
from graphrag.index.operations.extract_entities.typing import (
|
||||
Document,
|
||||
@ -27,8 +28,13 @@ async def run_graph_intelligence(
|
||||
args: StrategyConfig,
|
||||
) -> EntityExtractionResult:
|
||||
"""Run the graph intelligence entity extraction strategy."""
|
||||
llm_config = read_llm_params(args.get("llm", {}))
|
||||
llm = load_llm("entity_extraction", llm_config, callbacks=callbacks, cache=cache)
|
||||
llm_config = LanguageModelConfig(**args["llm"])
|
||||
llm = load_llm(
|
||||
"entity_extraction",
|
||||
llm_config,
|
||||
callbacks=callbacks,
|
||||
cache=cache,
|
||||
)
|
||||
return await run_extract_entities(llm, docs, entity_types, callbacks, args)
|
||||
|
||||
|
||||
|
||||
@ -10,7 +10,8 @@ from fnllm import ChatLLM
|
||||
|
||||
from graphrag.cache.pipeline_cache import PipelineCache
|
||||
from graphrag.callbacks.workflow_callbacks import WorkflowCallbacks
|
||||
from graphrag.index.llm.load_llm import load_llm, read_llm_params
|
||||
from graphrag.config.models.language_model_config import LanguageModelConfig
|
||||
from graphrag.index.llm.load_llm import load_llm
|
||||
from graphrag.index.operations.summarize_communities.community_reports_extractor.community_reports_extractor import (
|
||||
CommunityReportsExtractor,
|
||||
)
|
||||
@ -33,8 +34,13 @@ async def run_graph_intelligence(
|
||||
args: StrategyConfig,
|
||||
) -> CommunityReport | None:
|
||||
"""Run the graph intelligence entity extraction strategy."""
|
||||
llm_config = read_llm_params(args.get("llm", {}))
|
||||
llm = load_llm("community_reporting", llm_config, callbacks=callbacks, cache=cache)
|
||||
llm_config = LanguageModelConfig(**args["llm"])
|
||||
llm = load_llm(
|
||||
"community_reporting",
|
||||
llm_config,
|
||||
callbacks=callbacks,
|
||||
cache=cache,
|
||||
)
|
||||
return await _run_extractor(llm, community, input, level, args, callbacks)
|
||||
|
||||
|
||||
|
||||
@ -93,7 +93,12 @@ async def _generate_report(
|
||||
) -> CommunityReport | None:
|
||||
"""Generate a report for a single community."""
|
||||
return await runner(
|
||||
community_id, community_context, community_level, callbacks, cache, strategy
|
||||
community_id,
|
||||
community_context,
|
||||
community_level,
|
||||
callbacks,
|
||||
cache,
|
||||
strategy,
|
||||
)
|
||||
|
||||
|
||||
|
||||
@ -7,7 +7,8 @@ from fnllm import ChatLLM
|
||||
|
||||
from graphrag.cache.pipeline_cache import PipelineCache
|
||||
from graphrag.callbacks.workflow_callbacks import WorkflowCallbacks
|
||||
from graphrag.index.llm.load_llm import load_llm, read_llm_params
|
||||
from graphrag.config.models.language_model_config import LanguageModelConfig
|
||||
from graphrag.index.llm.load_llm import load_llm
|
||||
from graphrag.index.operations.summarize_descriptions.description_summary_extractor import (
|
||||
SummarizeExtractor,
|
||||
)
|
||||
@ -25,9 +26,12 @@ async def run_graph_intelligence(
|
||||
args: StrategyConfig,
|
||||
) -> SummarizedDescriptionResult:
|
||||
"""Run the graph intelligence entity extraction strategy."""
|
||||
llm_config = read_llm_params(args.get("llm", {}))
|
||||
llm_config = LanguageModelConfig(**args["llm"])
|
||||
llm = load_llm(
|
||||
"summarize_descriptions", llm_config, callbacks=callbacks, cache=cache
|
||||
"summarize_descriptions",
|
||||
llm_config,
|
||||
callbacks=callbacks,
|
||||
cache=cache,
|
||||
)
|
||||
return await run_summarize_descriptions(llm, id, descriptions, callbacks, args)
|
||||
|
||||
|
||||
@ -136,11 +136,7 @@ async def summarize_descriptions(
|
||||
):
|
||||
async with semaphore:
|
||||
results = await strategy_exec(
|
||||
id,
|
||||
descriptions,
|
||||
callbacks,
|
||||
cache,
|
||||
strategy_config,
|
||||
id, descriptions, callbacks, cache, strategy_config
|
||||
)
|
||||
ticker(1)
|
||||
return results
|
||||
|
||||
@ -23,10 +23,11 @@ ItemType = TypeVar("ItemType")
|
||||
class ParallelizationError(ValueError):
|
||||
"""Exception for invalid parallel processing."""
|
||||
|
||||
def __init__(self, num_errors: int):
|
||||
super().__init__(
|
||||
f"{num_errors} Errors occurred while running parallel transformation, could not complete!"
|
||||
)
|
||||
def __init__(self, num_errors: int, example: str | None = None):
|
||||
msg = f"{num_errors} Errors occurred while running parallel transformation, could not complete!"
|
||||
if example:
|
||||
msg += f"\nExample error: {example}"
|
||||
super().__init__(msg)
|
||||
|
||||
|
||||
async def derive_from_rows(
|
||||
@ -153,6 +154,6 @@ async def _derive_from_rows_base(
|
||||
callbacks.error("parallel transformation error", error, stack)
|
||||
|
||||
if len(errors) > 0:
|
||||
raise ParallelizationError(len(errors))
|
||||
raise ParallelizationError(len(errors), errors[0][1])
|
||||
|
||||
return result
|
||||
|
||||
@ -55,16 +55,14 @@ async def run_workflows(
    cache: PipelineCache | None = None,
    callbacks: list[WorkflowCallbacks] | None = None,
    logger: ProgressLogger | None = None,
    run_id: str | None = None,
    is_update_run: bool = False,
) -> AsyncIterable[PipelineRunResult]:
    """Run all workflows using a simplified pipeline."""
    run_id = run_id or time.strftime("%Y%m%d-%H%M%S")
    root_dir = config.root_dir or ""
    root_dir = config.root_dir
    progress_logger = logger or NullProgressLogger()
    callbacks = callbacks or [ConsoleWorkflowCallbacks()]
    callback_chain = create_callback_chain(callbacks, progress_logger)
    storage_config = config.storage.model_dump()  # type: ignore
    storage_config = config.output.model_dump()  # type: ignore
    storage = StorageFactory().create_storage(
        storage_type=storage_config["type"],  # type: ignore
        kwargs=storage_config,
@ -81,7 +79,7 @@ async def run_workflows(
    if is_update_run:
        progress_logger.info("Running incremental indexing.")

        update_storage_config = config.update_index_storage.model_dump()  # type: ignore
        update_storage_config = config.update_index_output.model_dump()  # type: ignore
        update_index_storage = StorageFactory().create_storage(
            storage_type=update_storage_config["type"],  # type: ignore
            kwargs=update_storage_config,

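With run_id and resume removed, driving the pipeline reduces to iterating this coroutine. A minimal sketch follows; only the run_workflows signature above comes from this change, while the import path and the PipelineRunResult attribute names are assumptions.

# Hypothetical driver; module path and result attributes are assumptions.
from graphrag.config.models.graph_rag_config import GraphRagConfig
from graphrag.index.run.run_workflows import run_workflows  # assumed module path


async def index(config: GraphRagConfig) -> None:
    """Run every workflow for an already-loaded GraphRagConfig."""
    async for result in run_workflows(config, is_update_run=False):
        print(result.workflow, result.errors or "ok")  # attribute names assumed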
@ -110,8 +110,11 @@ async def _run_entity_summarization(
|
||||
pd.DataFrame
|
||||
The updated entities dataframe with summarized descriptions.
|
||||
"""
|
||||
summarization_llm_settings = config.get_language_model_config(
|
||||
config.summarize_descriptions.model_id
|
||||
)
|
||||
summarization_strategy = config.summarize_descriptions.resolved_strategy(
|
||||
config.root_dir,
|
||||
config.root_dir, summarization_llm_settings
|
||||
)
|
||||
|
||||
# Prepare tasks for async summarization where needed
|
||||
@ -120,7 +123,11 @@ async def _run_entity_summarization(
|
||||
if isinstance(description, list) and len(description) > 1:
|
||||
# Run entity summarization asynchronously
|
||||
result = await run_entity_summarization(
|
||||
row["title"], description, callbacks, cache, summarization_strategy
|
||||
row["title"],
|
||||
description,
|
||||
callbacks,
|
||||
cache,
|
||||
summarization_strategy,
|
||||
)
|
||||
return result.description
|
||||
# Handle case where description is a single-item list or not a list
|
||||
|
||||
@ -10,8 +10,8 @@ import pandas as pd
|
||||
|
||||
from graphrag.cache.pipeline_cache import PipelineCache
|
||||
from graphrag.callbacks.workflow_callbacks import WorkflowCallbacks
|
||||
from graphrag.config.embeddings import get_embedded_fields, get_embedding_settings
|
||||
from graphrag.config.models.graph_rag_config import GraphRagConfig
|
||||
from graphrag.index.config.embeddings import get_embedded_fields, get_embedding_settings
|
||||
from graphrag.index.flows.generate_text_embeddings import generate_text_embeddings
|
||||
from graphrag.index.update.communities import (
|
||||
_merge_and_resolve_nodes,
|
||||
@ -150,7 +150,7 @@ async def update_dataframe_outputs(
|
||||
# Generate text embeddings
|
||||
progress_logger.info("Updating Text Embeddings")
|
||||
embedded_fields = get_embedded_fields(config)
|
||||
text_embed = get_embedding_settings(config.embeddings)
|
||||
text_embed = get_embedding_settings(config)
|
||||
await generate_text_embeddings(
|
||||
final_documents=final_documents_df,
|
||||
final_relationships=merged_relationships_df,
|
||||
|
||||
@ -15,9 +15,11 @@ from graphrag.logger.print_progress import ProgressLogger
def validate_config_names(logger: ProgressLogger, parameters: GraphRagConfig) -> None:
    """Validate config file for LLM deployment name typos."""
    # Validate Chat LLM configs
    # TODO: Replace default_chat_model with a way to select the model
    default_llm_settings = parameters.get_language_model_config("default_chat_model")
    llm = load_llm(
        "test-llm",
        parameters.llm,
        name="test-llm",
        config=default_llm_settings,
        callbacks=NoopWorkflowCallbacks(),
        cache=None,
    )
@ -29,9 +31,12 @@ def validate_config_names(logger: ProgressLogger, parameters: GraphRagConfig) ->
        sys.exit(1)

    # Validate Embeddings LLM configs
    embedding_llm_settings = parameters.get_language_model_config(
        parameters.embeddings.model_id
    )
    embed_llm = load_llm_embeddings(
        "test-embed-llm",
        parameters.embeddings.llm,
        name="test-embed-llm",
        llm_config=embedding_llm_settings,
        callbacks=NoopWorkflowCallbacks(),
        cache=None,
    )

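The same lookup pattern works anywhere a GraphRagConfig is available. A small illustrative helper (not part of this change) that reports which deployments a pipeline would use, using only attributes that appear elsewhere in this diff:

from graphrag.config.models.graph_rag_config import GraphRagConfig


def describe_models(config: GraphRagConfig) -> None:
    """Print the chat and embedding model settings resolved from the top-level models map."""
    chat = config.get_language_model_config("default_chat_model")
    embed = config.get_language_model_config(config.embeddings.model_id)
    print("chat model:", chat.model, "| async mode:", chat.async_mode)
    print("embedding model:", embed.model, "| threads:", embed.parallelization_num_threads)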
@ -33,19 +33,25 @@ async def run_workflow(
|
||||
claims = await load_table_from_storage(
|
||||
"create_final_covariates", context.storage
|
||||
)
|
||||
async_mode = config.community_reports.async_mode
|
||||
num_threads = config.community_reports.parallelization.num_threads
|
||||
summarization_strategy = config.community_reports.resolved_strategy(config.root_dir)
|
||||
|
||||
community_reports_llm_settings = config.get_language_model_config(
|
||||
config.community_reports.model_id
|
||||
)
|
||||
async_mode = community_reports_llm_settings.async_mode
|
||||
num_threads = community_reports_llm_settings.parallelization_num_threads
|
||||
summarization_strategy = config.community_reports.resolved_strategy(
|
||||
config.root_dir, community_reports_llm_settings
|
||||
)
|
||||
|
||||
output = await create_final_community_reports(
|
||||
nodes,
|
||||
edges,
|
||||
entities,
|
||||
communities,
|
||||
claims,
|
||||
callbacks,
|
||||
context.cache,
|
||||
summarization_strategy,
|
||||
nodes_input=nodes,
|
||||
edges_input=edges,
|
||||
entities=entities,
|
||||
communities=communities,
|
||||
claims_input=claims,
|
||||
callbacks=callbacks,
|
||||
cache=context.cache,
|
||||
summarization_strategy=summarization_strategy,
|
||||
async_mode=async_mode,
|
||||
num_threads=num_threads,
|
||||
)
|
||||
|
||||
@ -26,12 +26,15 @@ async def run_workflow(
|
||||
"create_base_text_units", context.storage
|
||||
)
|
||||
|
||||
claim_extraction_llm_settings = config.get_language_model_config(
|
||||
config.claim_extraction.model_id
|
||||
)
|
||||
extraction_strategy = config.claim_extraction.resolved_strategy(
|
||||
config.root_dir, config.encoding_model
|
||||
config.root_dir, claim_extraction_llm_settings
|
||||
)
|
||||
|
||||
async_mode = config.claim_extraction.async_mode
|
||||
num_threads = config.claim_extraction.parallelization.num_threads
|
||||
async_mode = claim_extraction_llm_settings.async_mode
|
||||
num_threads = claim_extraction_llm_settings.parallelization_num_threads
|
||||
|
||||
output = await create_final_covariates(
|
||||
text_units,
|
||||
|
||||
@ -28,24 +28,28 @@ async def run_workflow(
|
||||
"create_base_text_units", context.storage
|
||||
)
|
||||
|
||||
extraction_strategy = config.entity_extraction.resolved_strategy(
|
||||
config.root_dir, config.encoding_model
|
||||
entity_extraction_llm_settings = config.get_language_model_config(
|
||||
config.entity_extraction.model_id
|
||||
)
|
||||
extraction_num_threads = config.entity_extraction.parallelization.num_threads
|
||||
extraction_async_mode = config.entity_extraction.async_mode
|
||||
extraction_strategy = config.entity_extraction.resolved_strategy(
|
||||
config.root_dir, entity_extraction_llm_settings
|
||||
)
|
||||
extraction_num_threads = entity_extraction_llm_settings.parallelization_num_threads
|
||||
extraction_async_mode = entity_extraction_llm_settings.async_mode
|
||||
entity_types = config.entity_extraction.entity_types
|
||||
|
||||
summarization_llm_settings = config.get_language_model_config(
|
||||
config.summarize_descriptions.model_id
|
||||
)
|
||||
summarization_strategy = config.summarize_descriptions.resolved_strategy(
|
||||
config.root_dir,
|
||||
)
|
||||
summarization_num_threads = (
|
||||
config.summarize_descriptions.parallelization.num_threads
|
||||
config.root_dir, summarization_llm_settings
|
||||
)
|
||||
summarization_num_threads = summarization_llm_settings.parallelization_num_threads
|
||||
|
||||
base_entity_nodes, base_relationship_edges = await extract_graph(
|
||||
text_units,
|
||||
callbacks,
|
||||
context.cache,
|
||||
text_units=text_units,
|
||||
callbacks=callbacks,
|
||||
cache=context.cache,
|
||||
extraction_strategy=extraction_strategy,
|
||||
extraction_num_threads=extraction_num_threads,
|
||||
extraction_async_mode=extraction_async_mode,
|
||||
|
||||
@ -6,8 +6,8 @@
|
||||
import pandas as pd
|
||||
|
||||
from graphrag.callbacks.workflow_callbacks import WorkflowCallbacks
|
||||
from graphrag.config.embeddings import get_embedded_fields, get_embedding_settings
|
||||
from graphrag.config.models.graph_rag_config import GraphRagConfig
|
||||
from graphrag.index.config.embeddings import get_embedded_fields, get_embedding_settings
|
||||
from graphrag.index.context import PipelineRunContext
|
||||
from graphrag.index.flows.generate_text_embeddings import (
|
||||
generate_text_embeddings,
|
||||
@ -40,7 +40,7 @@ async def run_workflow(
|
||||
)
|
||||
|
||||
embedded_fields = get_embedded_fields(config)
|
||||
text_embed = get_embedding_settings(config.embeddings)
|
||||
text_embed = get_embedding_settings(config)
|
||||
|
||||
await generate_text_embeddings(
|
||||
final_documents=final_documents,
|
||||
|
||||
@ -16,3 +16,4 @@ MAX_TOKEN_COUNT = 2000
|
||||
MIN_CHUNK_SIZE = 200
|
||||
N_SUBSET_MAX = 300
|
||||
MIN_CHUNK_OVERLAP = 0
|
||||
PROMPT_TUNING_MODEL_ID = "default_chat_model"
|
||||
|
||||
@ -6,12 +6,10 @@
|
||||
import numpy as np
|
||||
import pandas as pd
|
||||
from fnllm import ChatLLM
|
||||
from pydantic import TypeAdapter
|
||||
|
||||
import graphrag.config.defaults as defs
|
||||
from graphrag.callbacks.noop_workflow_callbacks import NoopWorkflowCallbacks
|
||||
from graphrag.config.models.graph_rag_config import GraphRagConfig
|
||||
from graphrag.config.models.llm_parameters import LLMParameters
|
||||
from graphrag.index.input.factory import create_input
|
||||
from graphrag.index.llm.load_llm import load_llm_embeddings
|
||||
from graphrag.index.operations.chunk_text.chunk_text import chunk_text
|
||||
@ -60,8 +58,8 @@ async def load_docs_in_chunks(
|
||||
k: int = K,
|
||||
) -> list[str]:
|
||||
"""Load docs into chunks for generating prompts."""
|
||||
llm_config = TypeAdapter(LLMParameters).validate_python(
|
||||
config.embeddings.resolved_strategy()["llm"]
|
||||
embeddings_llm_settings = config.get_language_model_config(
|
||||
config.embeddings.model_id
|
||||
)
|
||||
|
||||
dataset = await create_input(config.input, logger, root)
|
||||
@ -96,8 +94,8 @@ async def load_docs_in_chunks(
|
||||
msg = "k must be an integer > 0"
|
||||
raise ValueError(msg)
|
||||
embedding_llm = load_llm_embeddings(
|
||||
"prompt_tuning_embeddings",
|
||||
llm_config,
|
||||
name="prompt_tuning_embeddings",
|
||||
llm_config=embeddings_llm_settings,
|
||||
callbacks=NoopWorkflowCallbacks(),
|
||||
cache=None,
|
||||
)
|
||||
|
||||
@ -45,9 +45,10 @@ def get_local_search_engine(
|
||||
system_prompt: str | None = None,
|
||||
) -> LocalSearch:
|
||||
"""Create a local search engine based on data + configuration."""
|
||||
default_llm_settings = config.get_language_model_config("default_chat_model")
|
||||
llm = get_llm(config)
|
||||
text_embedder = get_text_embedder(config)
|
||||
token_encoder = tiktoken.get_encoding(config.encoding_model)
|
||||
token_encoder = tiktoken.get_encoding(default_llm_settings.encoding_model)
|
||||
|
||||
ls_config = config.local_search
|
||||
|
||||
@ -102,7 +103,11 @@ def get_global_search_engine(
|
||||
general_knowledge_inclusion_prompt: str | None = None,
|
||||
) -> GlobalSearch:
|
||||
"""Create a global search engine based on data + configuration."""
|
||||
token_encoder = tiktoken.get_encoding(config.encoding_model)
|
||||
# TODO: Global search should select model based on config??
|
||||
default_llm_settings = config.get_language_model_config("default_chat_model")
|
||||
|
||||
# Here we get encoding based on specified encoding name
|
||||
token_encoder = tiktoken.get_encoding(default_llm_settings.encoding_model)
|
||||
gs_config = config.global_search
|
||||
|
||||
dynamic_community_selection_kwargs = {}
|
||||
@ -111,7 +116,8 @@ def get_global_search_engine(
|
||||
|
||||
dynamic_community_selection_kwargs.update({
|
||||
"llm": get_llm(config),
|
||||
"token_encoder": tiktoken.encoding_for_model(config.llm.model),
|
||||
# And here we get encoding based on model
|
||||
"token_encoder": tiktoken.encoding_for_model(default_llm_settings.model),
|
||||
"keep_parent": gs_config.dynamic_search_keep_parent,
|
||||
"num_repeats": gs_config.dynamic_search_num_repeats,
|
||||
"use_summary": gs_config.dynamic_search_use_summary,
|
||||
@ -178,9 +184,10 @@ def get_drift_search_engine(
|
||||
reduce_system_prompt: str | None = None,
|
||||
) -> DRIFTSearch:
|
||||
"""Create a local search engine based on data + configuration."""
|
||||
default_llm_settings = config.get_language_model_config("default_chat_model")
|
||||
llm = get_llm(config)
|
||||
text_embedder = get_text_embedder(config)
|
||||
token_encoder = tiktoken.get_encoding(config.encoding_model)
|
||||
token_encoder = tiktoken.get_encoding(default_llm_settings.encoding_model)
|
||||
|
||||
return DRIFTSearch(
|
||||
llm=llm,
|
||||
@ -208,9 +215,10 @@ def get_basic_search_engine(
|
||||
system_prompt: str | None = None,
|
||||
) -> BasicSearch:
|
||||
"""Create a basic search engine based on data + configuration."""
|
||||
default_llm_settings = config.get_language_model_config("default_chat_model")
|
||||
llm = get_llm(config)
|
||||
text_embedder = get_text_embedder(config)
|
||||
token_encoder = tiktoken.get_encoding(config.encoding_model)
|
||||
token_encoder = tiktoken.get_encoding(default_llm_settings.encoding_model)
|
||||
|
||||
ls_config = config.basic_search
|
||||
|
||||
|
||||
@ -14,61 +14,65 @@ from graphrag.query.llm.oai.typing import OpenaiApiType
|
||||
|
||||
def get_llm(config: GraphRagConfig) -> ChatOpenAI:
|
||||
"""Get the LLM client."""
|
||||
is_azure_client = config.llm.type == LLMType.AzureOpenAIChat
|
||||
debug_llm_key = config.llm.api_key or ""
|
||||
default_llm_settings = config.get_language_model_config("default_chat_model")
|
||||
is_azure_client = default_llm_settings.type == LLMType.AzureOpenAIChat
|
||||
debug_llm_key = default_llm_settings.api_key or ""
|
||||
llm_debug_info = {
|
||||
**config.llm.model_dump(),
|
||||
**default_llm_settings.model_dump(),
|
||||
"api_key": f"REDACTED,len={len(debug_llm_key)}",
|
||||
}
|
||||
audience = (
|
||||
config.llm.audience
|
||||
if config.llm.audience
|
||||
default_llm_settings.audience
|
||||
if default_llm_settings.audience
|
||||
else "https://cognitiveservices.azure.com/.default"
|
||||
)
|
||||
print(f"creating llm client with {llm_debug_info}") # noqa T201
|
||||
return ChatOpenAI(
|
||||
api_key=config.llm.api_key,
|
||||
api_key=default_llm_settings.api_key,
|
||||
azure_ad_token_provider=(
|
||||
get_bearer_token_provider(DefaultAzureCredential(), audience)
|
||||
if is_azure_client and not config.llm.api_key
|
||||
if is_azure_client and not default_llm_settings.api_key
|
||||
else None
|
||||
),
|
||||
api_base=config.llm.api_base,
|
||||
organization=config.llm.organization,
|
||||
model=config.llm.model,
|
||||
api_base=default_llm_settings.api_base,
|
||||
organization=default_llm_settings.organization,
|
||||
model=default_llm_settings.model,
|
||||
api_type=OpenaiApiType.AzureOpenAI if is_azure_client else OpenaiApiType.OpenAI,
|
||||
deployment_name=config.llm.deployment_name,
|
||||
api_version=config.llm.api_version,
|
||||
max_retries=config.llm.max_retries,
|
||||
request_timeout=config.llm.request_timeout,
|
||||
deployment_name=default_llm_settings.deployment_name,
|
||||
api_version=default_llm_settings.api_version,
|
||||
max_retries=default_llm_settings.max_retries,
|
||||
request_timeout=default_llm_settings.request_timeout,
|
||||
)
|
||||
|
||||
|
||||
def get_text_embedder(config: GraphRagConfig) -> OpenAIEmbedding:
|
||||
"""Get the LLM client for embeddings."""
|
||||
is_azure_client = config.embeddings.llm.type == LLMType.AzureOpenAIEmbedding
|
||||
debug_embedding_api_key = config.embeddings.llm.api_key or ""
|
||||
embeddings_llm_settings = config.get_language_model_config(
|
||||
config.embeddings.model_id
|
||||
)
|
||||
is_azure_client = embeddings_llm_settings.type == LLMType.AzureOpenAIEmbedding
|
||||
debug_embedding_api_key = embeddings_llm_settings.api_key or ""
|
||||
llm_debug_info = {
|
||||
**config.embeddings.llm.model_dump(),
|
||||
**embeddings_llm_settings.model_dump(),
|
||||
"api_key": f"REDACTED,len={len(debug_embedding_api_key)}",
|
||||
}
|
||||
if config.embeddings.llm.audience is None:
|
||||
if embeddings_llm_settings.audience is None:
|
||||
audience = "https://cognitiveservices.azure.com/.default"
|
||||
else:
|
||||
audience = config.embeddings.llm.audience
|
||||
audience = embeddings_llm_settings.audience
|
||||
print(f"creating embedding llm client with {llm_debug_info}") # noqa T201
|
||||
return OpenAIEmbedding(
|
||||
api_key=config.embeddings.llm.api_key,
|
||||
api_key=embeddings_llm_settings.api_key,
|
||||
azure_ad_token_provider=(
|
||||
get_bearer_token_provider(DefaultAzureCredential(), audience)
|
||||
if is_azure_client and not config.embeddings.llm.api_key
|
||||
if is_azure_client and not embeddings_llm_settings.api_key
|
||||
else None
|
||||
),
|
||||
api_base=config.embeddings.llm.api_base,
|
||||
organization=config.llm.organization,
|
||||
api_base=embeddings_llm_settings.api_base,
|
||||
organization=embeddings_llm_settings.organization,
|
||||
api_type=OpenaiApiType.AzureOpenAI if is_azure_client else OpenaiApiType.OpenAI,
|
||||
model=config.embeddings.llm.model,
|
||||
deployment_name=config.embeddings.llm.deployment_name,
|
||||
api_version=config.embeddings.llm.api_version,
|
||||
max_retries=config.embeddings.llm.max_retries,
|
||||
model=embeddings_llm_settings.model,
|
||||
deployment_name=embeddings_llm_settings.deployment_name,
|
||||
api_version=embeddings_llm_settings.api_version,
|
||||
max_retries=embeddings_llm_settings.max_retries,
|
||||
)
|
||||
|
||||
@ -11,7 +11,6 @@ from typing import Any
|
||||
import tiktoken
|
||||
from tqdm.asyncio import tqdm_asyncio
|
||||
|
||||
from graphrag.config.models.drift_search_config import DRIFTSearchConfig
from graphrag.query.context_builder.conversation_history import ConversationHistory
from graphrag.query.context_builder.entity_extraction import EntityVectorStoreKey
from graphrag.query.llm.oai.chat_openai import ChatOpenAI
@ -35,7 +34,6 @@ class DRIFTSearch(BaseSearch[DRIFTSearchContextBuilder]):
        self,
        llm: ChatOpenAI,
        context_builder: DRIFTSearchContextBuilder,
        config: DRIFTSearchConfig | None = None,
        token_encoder: tiktoken.Encoding | None = None,
        query_state: QueryState | None = None,
    ):
@ -51,12 +49,13 @@ class DRIFTSearch(BaseSearch[DRIFTSearchContextBuilder]):
        """
        super().__init__(llm, context_builder, token_encoder)

        self.config = config or DRIFTSearchConfig()
        self.context_builder = context_builder
        self.token_encoder = token_encoder
        self.query_state = query_state or QueryState()
        self.primer = DRIFTPrimer(
            config=self.config, chat_llm=llm, token_encoder=token_encoder
            config=self.context_builder.config,
            chat_llm=llm,
            token_encoder=token_encoder,
        )
        self.local_search = self.init_local_search()

@ -69,21 +68,21 @@ class DRIFTSearch(BaseSearch[DRIFTSearchContextBuilder]):
            LocalSearch: An instance of the LocalSearch class with the configured parameters.
        """
        local_context_params = {
            "text_unit_prop": self.config.local_search_text_unit_prop,
            "community_prop": self.config.local_search_community_prop,
            "top_k_mapped_entities": self.config.local_search_top_k_mapped_entities,
            "top_k_relationships": self.config.local_search_top_k_relationships,
            "text_unit_prop": self.context_builder.config.local_search_text_unit_prop,
            "community_prop": self.context_builder.config.local_search_community_prop,
            "top_k_mapped_entities": self.context_builder.config.local_search_top_k_mapped_entities,
            "top_k_relationships": self.context_builder.config.local_search_top_k_relationships,
            "include_entity_rank": True,
            "include_relationship_weight": True,
            "include_community_rank": False,
            "return_candidate_context": False,
            "embedding_vectorstore_key": EntityVectorStoreKey.ID,
            "max_tokens": self.config.local_search_max_data_tokens,
            "max_tokens": self.context_builder.config.local_search_max_data_tokens,
        }

        llm_params = {
            "max_tokens": self.config.local_search_llm_max_gen_tokens,
            "temperature": self.config.local_search_temperature,
            "max_tokens": self.context_builder.config.local_search_llm_max_gen_tokens,
            "temperature": self.context_builder.config.local_search_temperature,
            "response_format": {"type": "json_object"},
        }

@ -220,13 +219,15 @@ class DRIFTSearch(BaseSearch[DRIFTSearchContextBuilder]):
        # Main loop
        epochs = 0
        llm_call_offset = 0
        while epochs < self.config.n:
        while epochs < self.context_builder.config.n:
            actions = self.query_state.rank_incomplete_actions()
            if len(actions) == 0:
                log.info("No more actions to take. Exiting DRIFT loop.")
                break
            actions = actions[: self.config.drift_k_followups]
            llm_call_offset += len(actions) - self.config.drift_k_followups
            actions = actions[: self.context_builder.config.drift_k_followups]
            llm_call_offset += (
                len(actions) - self.context_builder.config.drift_k_followups
            )
            # Process actions
            results = await self.asearch_step(
                global_query=query, search_engine=self.local_search, actions=actions
@ -260,8 +261,8 @@ class DRIFTSearch(BaseSearch[DRIFTSearchContextBuilder]):
            llm_calls=llm_calls,
            prompt_tokens=prompt_tokens,
            output_tokens=output_tokens,
            max_tokens=self.config.reduce_max_tokens,
            temperature=self.config.reduce_temperature,
            max_tokens=self.context_builder.config.reduce_max_tokens,
            temperature=self.context_builder.config.reduce_temperature,
        )

        return SearchResult(
@ -317,8 +318,8 @@ class DRIFTSearch(BaseSearch[DRIFTSearchContextBuilder]):
        async for resp in self._reduce_response_streaming(
            responses=result.response,
            query=query,
            max_tokens=self.config.reduce_max_tokens,
            temperature=self.config.reduce_temperature,
            max_tokens=self.context_builder.config.reduce_max_tokens,
            temperature=self.context_builder.config.reduce_temperature,
        ):
            yield resp

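The net effect of the DRIFTSearch changes above is that the DRIFT tuning parameters now travel with the context builder rather than with the search object. A minimal sketch of the new wiring follows; it assumes the context builder is constructed elsewhere with a DRIFTSearchConfig, and the helper name and llm variable are illustrative, not part of this diff:

    from graphrag.config.models.drift_search_config import DRIFTSearchConfig

    # Illustrative values; unset fields fall back to the DRIFTSearchConfig defaults.
    drift_config = DRIFTSearchConfig(drift_k_followups=20)

    # Assumption: a context builder created with this config plus its usual inputs
    # (entities, reports, text units, vector store), all omitted here.
    context_builder = build_drift_context_builder(config=drift_config)  # hypothetical helper

    # chat_llm: a ChatOpenAI instance created elsewhere.
    search = DRIFTSearch(llm=chat_llm, context_builder=context_builder)
    # DRIFTSearch no longer accepts a config= argument; internally it reads values
    # such as search.context_builder.config.drift_k_followups and reduce_max_tokens.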
@ -7,7 +7,7 @@ from __future__ import annotations

from typing import TYPE_CHECKING, ClassVar

from graphrag.config.enums import StorageType
from graphrag.config.enums import OutputType
from graphrag.storage.blob_pipeline_storage import create_blob_storage
from graphrag.storage.cosmosdb_pipeline_storage import create_cosmosdb_storage
from graphrag.storage.file_pipeline_storage import create_file_storage
@ -35,17 +35,17 @@ class StorageFactory:

    @classmethod
    def create_storage(
        cls, storage_type: StorageType | str, kwargs: dict
        cls, storage_type: OutputType | str, kwargs: dict
    ) -> PipelineStorage:
        """Create or get a storage object from the provided type."""
        match storage_type:
            case StorageType.blob:
            case OutputType.blob:
                return create_blob_storage(**kwargs)
            case StorageType.cosmosdb:
            case OutputType.cosmosdb:
                return create_cosmosdb_storage(**kwargs)
            case StorageType.file:
            case OutputType.file:
                return create_file_storage(**kwargs)
            case StorageType.memory:
            case OutputType.memory:
                return MemoryPipelineStorage()
            case _:
                if storage_type in cls.storage_types:

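The factory's call shape is unchanged; only the enum moved from StorageType to OutputType. A short sketch of selecting a backend, mirroring the unit tests further down (paths are illustrative):

    from graphrag.config.enums import OutputType
    from graphrag.storage.factory import StorageFactory

    # File-backed pipeline storage.
    file_storage = StorageFactory.create_storage(
        OutputType.file, {"type": "file", "base_dir": "./output"}
    )

    # In-memory storage needs no paths at all.
    memory_storage = StorageFactory.create_storage(OutputType.memory, {"type": "memory"})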
@ -1,25 +0,0 @@
# Copyright (c) 2024 Microsoft Corporation.
# Licensed under the MIT License

"""Utilities for working with embeddings stores."""

from graphrag.index.config.embeddings import all_embeddings


def create_collection_name(
    container_name: str, embedding_name: str, validate: bool = True
) -> str:
    """
    Create a collection name for the embedding store.

    Within any given vector store, we can have multiple sets of embeddings organized into projects.
    The `container` param is used for this partitioning, and is added as a prefix to the collection name for differentiation.

    The embedding name is fixed, with the available list defined in graphrag.index.config.embeddings

    Note that we use dot notation in our names, but many vector stores do not support this - so we convert to dashes.
    """
    if validate and embedding_name not in all_embeddings:
        msg = f"Invalid embedding name: {embedding_name}"
        raise KeyError(msg)
    return f"{container_name}-{embedding_name}".replace(".", "-")
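For reference, the naming rule implemented by this removed helper is a one-line prefix-and-sanitize step; the embedding name below is illustrative:

    # Same behavior as the removed create_collection_name helper (sketch).
    container_name = "default"
    embedding_name = "entity.description"  # illustrative; real names come from the embeddings config
    collection = f"{container_name}-{embedding_name}".replace(".", "-")
    # -> "default-entity-description": dots become dashes for vector stores
    #    that do not accept dot notation in collection names.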
11
tests/fixtures/min-csv/config.json
vendored
@ -7,7 +7,6 @@
|
||||
1,
|
||||
2500
|
||||
],
|
||||
"subworkflows": 1,
|
||||
"max_runtime": 150,
|
||||
"expected_artifacts": 0
|
||||
},
|
||||
@ -16,7 +15,6 @@
|
||||
1,
|
||||
2500
|
||||
],
|
||||
"subworkflows": 1,
|
||||
"max_runtime": 150,
|
||||
"expected_artifacts": 0
|
||||
},
|
||||
@ -25,7 +23,6 @@
|
||||
1,
|
||||
2500
|
||||
],
|
||||
"subworkflows": 1,
|
||||
"max_runtime": 300,
|
||||
"expected_artifacts": 0
|
||||
},
|
||||
@ -38,7 +35,6 @@
|
||||
"type",
|
||||
"description"
|
||||
],
|
||||
"subworkflows": 1,
|
||||
"max_runtime": 300,
|
||||
"expected_artifacts": 1
|
||||
},
|
||||
@ -47,7 +43,6 @@
|
||||
1,
|
||||
6000
|
||||
],
|
||||
"subworkflows": 1,
|
||||
"max_runtime": 150,
|
||||
"expected_artifacts": 1
|
||||
},
|
||||
@ -61,7 +56,6 @@
|
||||
"x",
|
||||
"y"
|
||||
],
|
||||
"subworkflows": 1,
|
||||
"max_runtime": 150,
|
||||
"expected_artifacts": 1
|
||||
},
|
||||
@ -70,7 +64,6 @@
|
||||
1,
|
||||
2500
|
||||
],
|
||||
"subworkflows": 1,
|
||||
"max_runtime": 150,
|
||||
"expected_artifacts": 1
|
||||
},
|
||||
@ -90,7 +83,6 @@
|
||||
"period",
|
||||
"size"
|
||||
],
|
||||
"subworkflows": 1,
|
||||
"max_runtime": 300,
|
||||
"expected_artifacts": 1
|
||||
},
|
||||
@ -103,7 +95,6 @@
|
||||
"relationship_ids",
|
||||
"entity_ids"
|
||||
],
|
||||
"subworkflows": 1,
|
||||
"max_runtime": 150,
|
||||
"expected_artifacts": 1
|
||||
},
|
||||
@ -112,7 +103,6 @@
|
||||
1,
|
||||
2500
|
||||
],
|
||||
"subworkflows": 1,
|
||||
"max_runtime": 150,
|
||||
"expected_artifacts": 1
|
||||
},
|
||||
@ -121,7 +111,6 @@
|
||||
1,
|
||||
2500
|
||||
],
|
||||
"subworkflows": 1,
|
||||
"max_runtime": 150,
|
||||
"expected_artifacts": 1
|
||||
}
|
||||
|
||||
52
tests/fixtures/min-csv/settings.yml
vendored
@ -1,24 +1,50 @@
models:
  default_chat_model:
    type: ${GRAPHRAG_LLM_TYPE}
    api_key: ${GRAPHRAG_API_KEY}
    api_base: ${GRAPHRAG_API_BASE}
    api_version: ${GRAPHRAG_API_VERSION}
    deployment_name: ${GRAPHRAG_LLM_DEPLOYMENT_NAME}
    model: ${GRAPHRAG_LLM_MODEL}
    tokens_per_minute: ${GRAPHRAG_LLM_TPM}
    requests_per_minute: ${GRAPHRAG_LLM_RPM}
    model_supports_json: true
    parallelization_num_threads: 50
    parallelization_stagger: 0.3
    async_mode: threaded
  default_embedding_model:
    type: ${GRAPHRAG_EMBEDDING_TYPE}
    api_key: ${GRAPHRAG_API_KEY}
    api_base: ${GRAPHRAG_API_BASE}
    api_version: ${GRAPHRAG_API_VERSION}
    deployment_name: ${GRAPHRAG_EMBEDDING_DEPLOYMENT_NAME}
    model: ${GRAPHRAG_EMBEDDING_MODEL}
    tokens_per_minute: ${GRAPHRAG_EMBEDDING_TPM}
    requests_per_minute: ${GRAPHRAG_EMBEDDING_RPM}
    parallelization_num_threads: 50
    parallelization_stagger: 0.3
    async_mode: threaded

vector_store:
  type: "lancedb"
  db_uri: "./tests/fixtures/min-csv/lancedb"
  container_name: "lancedb_ci"
  overwrite: True

input:
  file_type: csv
  file_pattern: ".*\\.csv$$"

embeddings:
  vector_store:
    type: "lancedb"
    db_uri: "./tests/fixtures/min-csv/lancedb"
    container_name: "lancedb_ci"
    overwrite: True
  model_id: "default_embedding_model"

storage:
  type: file # or blob
  base_dir: "output/${timestamp}/artifacts"
  # connection_string: <azure_blob_storage_connection_string>
  # container_name: <azure_blob_storage_container_name>
  type: file
  base_dir: "output"

reporting:
  type: file # or console, blob
  base_dir: "output/${timestamp}/reports"
  # connection_string: <azure_blob_storage_connection_string>
  # container_name: <azure_blob_storage_container_name>
  type: file
  base_dir: "logs"

snapshots:
  embeddings: True
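Two conventions in the fixture above are worth calling out: ${VAR} placeholders are substituted from the environment when the file is loaded, and a literal regex "$" anchor has to be written as "$$" (as in the csv file_pattern) so it is not treated as a template variable. A minimal sketch of loading this fixture, assuming the referenced GRAPHRAG_* variables are set (values below are illustrative):

    import os
    from pathlib import Path

    from graphrag.config.load_config import load_config

    # Illustrative values for the placeholders referenced by settings.yml.
    os.environ["GRAPHRAG_LLM_TYPE"] = "openai_chat"
    os.environ["GRAPHRAG_LLM_MODEL"] = "gpt-4-turbo-preview"
    os.environ["GRAPHRAG_API_KEY"] = "not-a-real-key"
    # ...plus the remaining GRAPHRAG_* variables used in the file.

    config = load_config(root_dir=Path("tests/fixtures/min-csv"))
    print(config.models["default_chat_model"].model)
    print(config.input.file_pattern)  # the doubled $$ is read back as a single regex $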
12
tests/fixtures/text/config.json
vendored
@ -7,7 +7,6 @@
|
||||
1,
|
||||
2500
|
||||
],
|
||||
"subworkflows": 1,
|
||||
"max_runtime": 150,
|
||||
"expected_artifacts": 0
|
||||
},
|
||||
@ -16,7 +15,6 @@
|
||||
1,
|
||||
2500
|
||||
],
|
||||
"subworkflows": 1,
|
||||
"max_runtime": 150,
|
||||
"expected_artifacts": 0
|
||||
},
|
||||
@ -25,7 +23,6 @@
|
||||
1,
|
||||
2500
|
||||
],
|
||||
"subworkflows": 1,
|
||||
"max_runtime": 300,
|
||||
"expected_artifacts": 0
|
||||
},
|
||||
@ -43,7 +40,6 @@
|
||||
"end_date",
|
||||
"source_text"
|
||||
],
|
||||
"subworkflows": 1,
|
||||
"max_runtime": 300,
|
||||
"expected_artifacts": 1
|
||||
},
|
||||
@ -56,7 +52,6 @@
|
||||
"type",
|
||||
"description"
|
||||
],
|
||||
"subworkflows": 1,
|
||||
"max_runtime": 300,
|
||||
"expected_artifacts": 1
|
||||
},
|
||||
@ -65,7 +60,6 @@
|
||||
1,
|
||||
6000
|
||||
],
|
||||
"subworkflows": 1,
|
||||
"max_runtime": 150,
|
||||
"expected_artifacts": 1
|
||||
},
|
||||
@ -79,7 +73,6 @@
|
||||
"x",
|
||||
"y"
|
||||
],
|
||||
"subworkflows": 1,
|
||||
"max_runtime": 150,
|
||||
"expected_artifacts": 1
|
||||
},
|
||||
@ -88,7 +81,6 @@
|
||||
1,
|
||||
2500
|
||||
],
|
||||
"subworkflows": 1,
|
||||
"max_runtime": 150,
|
||||
"expected_artifacts": 1
|
||||
},
|
||||
@ -108,7 +100,6 @@
|
||||
"period",
|
||||
"size"
|
||||
],
|
||||
"subworkflows": 1,
|
||||
"max_runtime": 300,
|
||||
"expected_artifacts": 1
|
||||
},
|
||||
@ -121,7 +112,6 @@
|
||||
"relationship_ids",
|
||||
"entity_ids"
|
||||
],
|
||||
"subworkflows": 1,
|
||||
"max_runtime": 150,
|
||||
"expected_artifacts": 1
|
||||
},
|
||||
@ -130,7 +120,6 @@
|
||||
1,
|
||||
2500
|
||||
],
|
||||
"subworkflows": 1,
|
||||
"max_runtime": 150,
|
||||
"expected_artifacts": 1
|
||||
},
|
||||
@ -139,7 +128,6 @@
|
||||
1,
|
||||
2500
|
||||
],
|
||||
"subworkflows": 1,
|
||||
"max_runtime": 150,
|
||||
"expected_artifacts": 1
|
||||
}
|
||||
|
||||
51
tests/fixtures/text/settings.yml
vendored
@ -1,12 +1,41 @@
models:
  default_chat_model:
    type: ${GRAPHRAG_LLM_TYPE}
    api_key: ${GRAPHRAG_API_KEY}
    api_base: ${GRAPHRAG_API_BASE}
    api_version: ${GRAPHRAG_API_VERSION}
    deployment_name: ${GRAPHRAG_LLM_DEPLOYMENT_NAME}
    model: ${GRAPHRAG_LLM_MODEL}
    tokens_per_minute: ${GRAPHRAG_LLM_TPM}
    requests_per_minute: ${GRAPHRAG_LLM_RPM}
    model_supports_json: true
    parallelization_num_threads: 50
    parallelization_stagger: 0.3
    async_mode: threaded
  default_embedding_model:
    type: ${GRAPHRAG_EMBEDDING_TYPE}
    api_key: ${GRAPHRAG_API_KEY}
    api_base: ${GRAPHRAG_API_BASE}
    api_version: ${GRAPHRAG_API_VERSION}
    deployment_name: ${GRAPHRAG_EMBEDDING_DEPLOYMENT_NAME}
    model: ${GRAPHRAG_EMBEDDING_MODEL}
    tokens_per_minute: ${GRAPHRAG_EMBEDDING_TPM}
    requests_per_minute: ${GRAPHRAG_EMBEDDING_RPM}
    parallelization_num_threads: 50
    parallelization_stagger: 0.3
    async_mode: threaded

vector_store:
  type: "azure_ai_search"
  url: ${AZURE_AI_SEARCH_URL_ENDPOINT}
  api_key: ${AZURE_AI_SEARCH_API_KEY}
  container_name: "simple_text_ci"

claim_extraction:
  enabled: true

embeddings:
  vector_store:
    type: "azure_ai_search"
    url: ${AZURE_AI_SEARCH_URL_ENDPOINT}
    api_key: ${AZURE_AI_SEARCH_API_KEY}
    container_name: "simple_text_ci"
  model_id: "default_embedding_model"

community_reports:
  prompt: "prompts/community_report.txt"
@ -15,16 +44,12 @@ community_reports:


storage:
  type: file # or blob
  base_dir: "output/${timestamp}/artifacts"
  # connection_string: <azure_blob_storage_connection_string>
  # container_name: <azure_blob_storage_container_name>
  type: file
  base_dir: "output"

reporting:
  type: file # or console, blob
  base_dir: "output/${timestamp}/reports"
  # connection_string: <azure_blob_storage_connection_string>
  # container_name: <azure_blob_storage_container_name>
  type: file
  base_dir: "logs"

snapshots:
  embeddings: True
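In both fixtures the embeddings block no longer carries its own vector store or credentials: the store is configured once at the top level, and the embedding model is picked by id from the models map. A minimal sketch of how that indirection resolves at runtime, reusing the config object from the earlier load_config sketch:

    # The embeddings section points at an entry in the top-level models map.
    model_id = config.embeddings.model_id        # "default_embedding_model"
    embedding_model = config.models[model_id]    # a LanguageModelConfig
    print(embedding_model.type, embedding_model.model)

    # The vector store is defined once at the top level of the config.
    print(config.vector_store.type, config.vector_store.container_name)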
@ -9,7 +9,7 @@ import sys

import pytest

from graphrag.config.enums import StorageType
from graphrag.config.enums import OutputType
from graphrag.storage.blob_pipeline_storage import BlobPipelineStorage
from graphrag.storage.cosmosdb_pipeline_storage import CosmosDBPipelineStorage
from graphrag.storage.factory import StorageFactory
@ -29,7 +29,7 @@ def test_create_blob_storage():
        "base_dir": "testbasedir",
        "container_name": "testcontainer",
    }
    storage = StorageFactory.create_storage(StorageType.blob, kwargs)
    storage = StorageFactory.create_storage(OutputType.blob, kwargs)
    assert isinstance(storage, BlobPipelineStorage)


@ -44,19 +44,19 @@ def test_create_cosmosdb_storage():
        "base_dir": "testdatabase",
        "container_name": "testcontainer",
    }
    storage = StorageFactory.create_storage(StorageType.cosmosdb, kwargs)
    storage = StorageFactory.create_storage(OutputType.cosmosdb, kwargs)
    assert isinstance(storage, CosmosDBPipelineStorage)


def test_create_file_storage():
    kwargs = {"type": "file", "base_dir": "/tmp/teststorage"}
    storage = StorageFactory.create_storage(StorageType.file, kwargs)
    storage = StorageFactory.create_storage(OutputType.file, kwargs)
    assert isinstance(storage, FilePipelineStorage)


def test_create_memory_storage():
    kwargs = {"type": "memory"}
    storage = StorageFactory.create_storage(StorageType.memory, kwargs)
    storage = StorageFactory.create_storage(OutputType.memory, kwargs)
    assert isinstance(storage, MemoryPipelineStorage)

@ -152,24 +152,12 @@ class TestIndexer:
    def __assert_indexer_outputs(
        self, root: Path, workflow_config: dict[str, dict[str, Any]]
    ):
        outputs_path = root / "output"
        output_entries = list(outputs_path.iterdir())
        # Sort the output folders by creation time, most recent
        output_entries.sort(key=lambda entry: entry.stat().st_ctime, reverse=True)
        output_path = root / "output"

        if not debug:
            assert len(output_entries) == 1, (
                f"Expected one output folder, found {len(output_entries)}"
            )

        output_path = output_entries[0]
        assert output_path.exists(), "output folder does not exist"

        artifacts = output_path / "artifacts"
        assert artifacts.exists(), "artifact folder does not exist"

        # Check stats for all workflow
        stats = json.loads((artifacts / "stats.json").read_bytes().decode("utf-8"))
        stats = json.loads((output_path / "stats.json").read_bytes().decode("utf-8"))

        # Check all workflows run
        expected_artifacts = 0
@ -193,7 +181,7 @@ class TestIndexer:
        )

        # Check artifacts
        artifact_files = os.listdir(artifacts)
        artifact_files = os.listdir(output_path)

        # check that the number of workflows matches the number of artifacts
        assert len(artifact_files) == (expected_artifacts + 3), (
@ -202,7 +190,7 @@ class TestIndexer:

        for artifact in artifact_files:
            if artifact.endswith(".parquet"):
                output_df = pd.read_parquet(artifacts / artifact)
                output_df = pd.read_parquet(output_path / artifact)
                artifact_name = artifact.split(".")[0]

                try:
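The assertion helper above now expects a single flat output directory per run, with no timestamped subfolders and no artifacts/ layer, so stats.json and the parquet artifacts sit side by side. Inspecting a run by hand follows the same shape; the path below is illustrative:

    import json
    from pathlib import Path

    output_path = Path("output")  # flat output dir, no timestamped subfolder
    stats = json.loads((output_path / "stats.json").read_text(encoding="utf-8"))
    parquet_files = sorted(p.name for p in output_path.glob("*.parquet"))

    print(sorted(stats))   # top-level keys recorded by the pipeline stats
    print(parquet_files)   # one parquet artifact per emitted table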
9
tests/unit/config/fixtures/minimal_config/settings.yaml
Normal file
@ -0,0 +1,9 @@
models:
  default_chat_model:
    api_key: ${CUSTOM_API_KEY}
    type: openai_chat
    model: gpt-4-turbo-preview
  default_embedding_model:
    api_key: ${CUSTOM_API_KEY}
    type: openai_embedding
    model: text-embedding-3-small
@ -0,0 +1,9 @@
models:
  default_chat_model:
    api_key: ${SOME_NON_EXISTENT_ENV_VAR}
    type: openai_chat
    model: gpt-4-turbo-preview
  default_embedding_model:
    api_key: ${SOME_NON_EXISTENT_ENV_VAR}
    type: openai_embedding
    model: text-embedding-3-small
168
tests/unit/config/test_config.py
Normal file
@ -0,0 +1,168 @@
# Copyright (c) 2024 Microsoft Corporation.
# Licensed under the MIT License

import os
from pathlib import Path
from unittest import mock

import pytest
from pydantic import ValidationError

import graphrag.config.defaults as defs
from graphrag.config.create_graphrag_config import create_graphrag_config
from graphrag.config.enums import AzureAuthType, LLMType
from graphrag.config.load_config import load_config
from tests.unit.config.utils import (
    DEFAULT_EMBEDDING_MODEL_CONFIG,
    DEFAULT_MODEL_CONFIG,
    FAKE_API_KEY,
    assert_graphrag_configs,
    get_default_graphrag_config,
)


def test_missing_openai_required_api_key() -> None:
    model_config_missing_api_key = {
        defs.DEFAULT_CHAT_MODEL_ID: {
            "type": LLMType.OpenAIChat,
            "model": defs.LLM_MODEL,
        },
        defs.DEFAULT_EMBEDDING_MODEL_ID: DEFAULT_EMBEDDING_MODEL_CONFIG,
    }

    # API Key required for OpenAIChat
    with pytest.raises(ValidationError):
        create_graphrag_config({"models": model_config_missing_api_key})

    # API Key required for OpenAIEmbedding
    model_config_missing_api_key[defs.DEFAULT_CHAT_MODEL_ID]["type"] = (
        LLMType.OpenAIEmbedding
    )
    with pytest.raises(ValidationError):
        create_graphrag_config({"models": model_config_missing_api_key})


def test_missing_azure_api_key() -> None:
    model_config_missing_api_key = {
        defs.DEFAULT_CHAT_MODEL_ID: {
            "type": LLMType.AzureOpenAIChat,
            "azure_auth_type": AzureAuthType.APIKey,
            "model": defs.LLM_MODEL,
            "api_base": "some_api_base",
            "api_version": "some_api_version",
            "deployment_name": "some_deployment_name",
        },
        defs.DEFAULT_EMBEDDING_MODEL_ID: DEFAULT_EMBEDDING_MODEL_CONFIG,
    }

    with pytest.raises(ValidationError):
        create_graphrag_config({"models": model_config_missing_api_key})

    # API Key not required for managed identity
    model_config_missing_api_key[defs.DEFAULT_CHAT_MODEL_ID]["azure_auth_type"] = (
        AzureAuthType.ManagedIdentity
    )
    create_graphrag_config({"models": model_config_missing_api_key})


def test_conflicting_azure_api_key() -> None:
    model_config_conflicting_api_key = {
        defs.DEFAULT_CHAT_MODEL_ID: {
            "type": LLMType.AzureOpenAIChat,
            "azure_auth_type": AzureAuthType.ManagedIdentity,
            "model": defs.LLM_MODEL,
            "api_base": "some_api_base",
            "api_version": "some_api_version",
            "deployment_name": "some_deployment_name",
            "api_key": "THIS_SHOULD_NOT_BE_SET_WHEN_USING_MANAGED_IDENTITY",
        },
        defs.DEFAULT_EMBEDDING_MODEL_ID: DEFAULT_EMBEDDING_MODEL_CONFIG,
    }

    with pytest.raises(ValidationError):
        create_graphrag_config({"models": model_config_conflicting_api_key})


base_azure_model_config = {
    "type": LLMType.AzureOpenAIChat,
    "azure_auth_type": AzureAuthType.ManagedIdentity,
    "model": defs.LLM_MODEL,
    "api_base": "some_api_base",
    "api_version": "some_api_version",
    "deployment_name": "some_deployment_name",
}


def test_missing_azure_api_base() -> None:
    missing_api_base_config = base_azure_model_config.copy()
    del missing_api_base_config["api_base"]

    with pytest.raises(ValidationError):
        create_graphrag_config({
            "models": {
                defs.DEFAULT_CHAT_MODEL_ID: missing_api_base_config,
                defs.DEFAULT_EMBEDDING_MODEL_ID: DEFAULT_EMBEDDING_MODEL_CONFIG,
            }
        })


def test_missing_azure_api_version() -> None:
    missing_api_version_config = base_azure_model_config.copy()
    del missing_api_version_config["api_version"]

    with pytest.raises(ValidationError):
        create_graphrag_config({
            "models": {
                defs.DEFAULT_CHAT_MODEL_ID: missing_api_version_config,
                defs.DEFAULT_EMBEDDING_MODEL_ID: DEFAULT_EMBEDDING_MODEL_CONFIG,
            }
        })


def test_missing_azure_deployment_name() -> None:
    missing_deployment_name_config = base_azure_model_config.copy()
    del missing_deployment_name_config["deployment_name"]

    with pytest.raises(ValidationError):
        create_graphrag_config({
            "models": {
                defs.DEFAULT_CHAT_MODEL_ID: missing_deployment_name_config,
                defs.DEFAULT_EMBEDDING_MODEL_ID: DEFAULT_EMBEDDING_MODEL_CONFIG,
            }
        })


def test_default_config() -> None:
    expected = get_default_graphrag_config()
    actual = create_graphrag_config({"models": DEFAULT_MODEL_CONFIG})
    assert_graphrag_configs(actual, expected)


@mock.patch.dict(os.environ, {"CUSTOM_API_KEY": FAKE_API_KEY}, clear=True)
def test_load_minimal_config() -> None:
    cwd = Path(__file__).parent
    root_dir = (cwd / "fixtures" / "minimal_config").resolve()
    expected = get_default_graphrag_config(str(root_dir))
    actual = load_config(root_dir=root_dir)
    assert_graphrag_configs(actual, expected)


@mock.patch.dict(os.environ, {"CUSTOM_API_KEY": FAKE_API_KEY}, clear=True)
def test_load_config_with_cli_overrides() -> None:
    cwd = Path(__file__).parent
    root_dir = (cwd / "fixtures" / "minimal_config").resolve()
    output_dir = "some_output_dir"
    expected_output_base_dir = root_dir / output_dir
    expected = get_default_graphrag_config(str(root_dir))
    expected.output.base_dir = str(expected_output_base_dir)
    actual = load_config(
        root_dir=root_dir, cli_overrides={"output.base_dir": output_dir}
    )
    assert_graphrag_configs(actual, expected)


def test_load_config_missing_env_vars() -> None:
    cwd = Path(__file__).parent
    root_dir = (cwd / "fixtures" / "minimal_config_missing_env_var").resolve()
    with pytest.raises(KeyError):
        load_config(root_dir=root_dir)
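The failing cases above all come from pydantic validation of LanguageModelConfig; the valid counterpart is simply a models map whose entries satisfy the rules for their type. A small sketch of a configuration these validators should accept (ids mirror the tests; endpoint, version, and deployment values are illustrative):

    import graphrag.config.defaults as defs
    from graphrag.config.create_graphrag_config import create_graphrag_config
    from graphrag.config.enums import AzureAuthType, LLMType

    models = {
        defs.DEFAULT_CHAT_MODEL_ID: {
            "type": LLMType.AzureOpenAIChat,
            "azure_auth_type": AzureAuthType.ManagedIdentity,  # so no api_key is required
            "model": defs.LLM_MODEL,
            "api_base": "https://example.openai.azure.com",     # illustrative endpoint
            "api_version": "2024-02-15-preview",                # illustrative version
            "deployment_name": "chat-deployment",               # illustrative deployment
        },
        defs.DEFAULT_EMBEDDING_MODEL_ID: {
            "type": LLMType.OpenAIEmbedding,
            "api_key": "not-a-real-key",
            "model": "text-embedding-3-small",
        },
    }

    config = create_graphrag_config({"models": models})
    print(config.models[defs.DEFAULT_CHAT_MODEL_ID].deployment_name)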
@ -1,556 +0,0 @@
|
||||
# Copyright (c) 2024 Microsoft Corporation.
|
||||
# Licensed under the MIT License
|
||||
import json
|
||||
import os
|
||||
import re
|
||||
import unittest
|
||||
from pathlib import Path
|
||||
from unittest import mock
|
||||
|
||||
import pytest
|
||||
import yaml
|
||||
|
||||
import graphrag.config.defaults as defs
|
||||
from graphrag.config.create_graphrag_config import create_graphrag_config
|
||||
from graphrag.config.errors import (
|
||||
ApiKeyMissingError,
|
||||
AzureApiBaseMissingError,
|
||||
AzureDeploymentNameMissingError,
|
||||
)
|
||||
from graphrag.config.models.basic_search_config import BasicSearchConfig
|
||||
from graphrag.config.models.cache_config import CacheConfig
|
||||
from graphrag.config.models.chunking_config import ChunkingConfig
|
||||
from graphrag.config.models.claim_extraction_config import ClaimExtractionConfig
|
||||
from graphrag.config.models.cluster_graph_config import ClusterGraphConfig
|
||||
from graphrag.config.models.community_reports_config import CommunityReportsConfig
|
||||
from graphrag.config.models.drift_search_config import DRIFTSearchConfig
|
||||
from graphrag.config.models.embed_graph_config import EmbedGraphConfig
|
||||
from graphrag.config.models.entity_extraction_config import EntityExtractionConfig
|
||||
from graphrag.config.models.global_search_config import GlobalSearchConfig
|
||||
from graphrag.config.models.graph_rag_config import GraphRagConfig
|
||||
from graphrag.config.models.input_config import InputConfig
|
||||
from graphrag.config.models.llm_parameters import LLMParameters
|
||||
from graphrag.config.models.local_search_config import LocalSearchConfig
|
||||
from graphrag.config.models.parallelization_parameters import ParallelizationParameters
|
||||
from graphrag.config.models.reporting_config import ReportingConfig
|
||||
from graphrag.config.models.snapshots_config import SnapshotsConfig
|
||||
from graphrag.config.models.storage_config import StorageConfig
|
||||
from graphrag.config.models.summarize_descriptions_config import (
|
||||
SummarizeDescriptionsConfig,
|
||||
)
|
||||
from graphrag.config.models.text_embedding_config import TextEmbeddingConfig
|
||||
from graphrag.config.models.umap_config import UmapConfig
|
||||
from graphrag.index.config.cache import PipelineFileCacheConfig
|
||||
from graphrag.index.config.input import (
|
||||
PipelineInputConfig,
|
||||
)
|
||||
from graphrag.index.config.reporting import PipelineFileReportingConfig
|
||||
from graphrag.index.config.storage import PipelineFileStorageConfig
|
||||
from graphrag.index.create_pipeline_config import create_pipeline_config
|
||||
|
||||
current_dir = os.path.dirname(__file__)
|
||||
|
||||
ALL_ENV_VARS = {
|
||||
"GRAPHRAG_API_BASE": "http://some/base",
|
||||
"GRAPHRAG_API_KEY": "test",
|
||||
"GRAPHRAG_API_ORGANIZATION": "test_org",
|
||||
"GRAPHRAG_API_PROXY": "http://some/proxy",
|
||||
"GRAPHRAG_API_VERSION": "v1234",
|
||||
"GRAPHRAG_ASYNC_MODE": "asyncio",
|
||||
"GRAPHRAG_CACHE_STORAGE_ACCOUNT_BLOB_URL": "cache_account_blob_url",
|
||||
"GRAPHRAG_CACHE_BASE_DIR": "/some/cache/dir",
|
||||
"GRAPHRAG_CACHE_CONNECTION_STRING": "test_cs1",
|
||||
"GRAPHRAG_CACHE_CONTAINER_NAME": "test_cn1",
|
||||
"GRAPHRAG_CACHE_TYPE": "blob",
|
||||
"GRAPHRAG_CHUNK_BY_COLUMNS": "a,b",
|
||||
"GRAPHRAG_CHUNK_OVERLAP": "12",
|
||||
"GRAPHRAG_CHUNK_SIZE": "500",
|
||||
"GRAPHRAG_CHUNK_ENCODING_MODEL": "encoding-c",
|
||||
"GRAPHRAG_CLAIM_EXTRACTION_ENABLED": "True",
|
||||
"GRAPHRAG_CLAIM_EXTRACTION_DESCRIPTION": "test 123",
|
||||
"GRAPHRAG_CLAIM_EXTRACTION_MAX_GLEANINGS": "5000",
|
||||
"GRAPHRAG_CLAIM_EXTRACTION_PROMPT_FILE": "tests/unit/config/prompt-a.txt",
|
||||
"GRAPHRAG_CLAIM_EXTRACTION_ENCODING_MODEL": "encoding_a",
|
||||
"GRAPHRAG_COMMUNITY_REPORTS_MAX_LENGTH": "23456",
|
||||
"GRAPHRAG_COMMUNITY_REPORTS_PROMPT_FILE": "tests/unit/config/prompt-b.txt",
|
||||
"GRAPHRAG_EMBEDDING_BATCH_MAX_TOKENS": "17",
|
||||
"GRAPHRAG_EMBEDDING_BATCH_SIZE": "1000000",
|
||||
"GRAPHRAG_EMBEDDING_CONCURRENT_REQUESTS": "12",
|
||||
"GRAPHRAG_EMBEDDING_DEPLOYMENT_NAME": "model-deployment-name",
|
||||
"GRAPHRAG_EMBEDDING_MAX_RETRIES": "3",
|
||||
"GRAPHRAG_EMBEDDING_MAX_RETRY_WAIT": "0.1123",
|
||||
"GRAPHRAG_EMBEDDING_MODEL": "text-embedding-2",
|
||||
"GRAPHRAG_EMBEDDING_REQUESTS_PER_MINUTE": "500",
|
||||
"GRAPHRAG_EMBEDDING_SKIP": "a1,b1,c1",
|
||||
"GRAPHRAG_EMBEDDING_SLEEP_ON_RATE_LIMIT_RECOMMENDATION": "False",
|
||||
"GRAPHRAG_EMBEDDING_TARGET": "all",
|
||||
"GRAPHRAG_EMBEDDING_THREAD_COUNT": "2345",
|
||||
"GRAPHRAG_EMBEDDING_THREAD_STAGGER": "0.456",
|
||||
"GRAPHRAG_EMBEDDING_TOKENS_PER_MINUTE": "7000",
|
||||
"GRAPHRAG_EMBEDDING_TYPE": "azure_openai_embedding",
|
||||
"GRAPHRAG_ENCODING_MODEL": "test123",
|
||||
"GRAPHRAG_INPUT_STORAGE_ACCOUNT_BLOB_URL": "input_account_blob_url",
|
||||
"GRAPHRAG_ENTITY_EXTRACTION_ENTITY_TYPES": "cat,dog,elephant",
|
||||
"GRAPHRAG_ENTITY_EXTRACTION_MAX_GLEANINGS": "112",
|
||||
"GRAPHRAG_ENTITY_EXTRACTION_PROMPT_FILE": "tests/unit/config/prompt-c.txt",
|
||||
"GRAPHRAG_ENTITY_EXTRACTION_ENCODING_MODEL": "encoding_b",
|
||||
"GRAPHRAG_INPUT_BASE_DIR": "/some/input/dir",
|
||||
"GRAPHRAG_INPUT_CONNECTION_STRING": "input_cs",
|
||||
"GRAPHRAG_INPUT_CONTAINER_NAME": "input_cn",
|
||||
"GRAPHRAG_INPUT_DOCUMENT_ATTRIBUTE_COLUMNS": "test1,test2",
|
||||
"GRAPHRAG_INPUT_ENCODING": "utf-16",
|
||||
"GRAPHRAG_INPUT_FILE_PATTERN": ".*\\test\\.txt$",
|
||||
"GRAPHRAG_INPUT_SOURCE_COLUMN": "test_source",
|
||||
"GRAPHRAG_INPUT_TYPE": "blob",
|
||||
"GRAPHRAG_INPUT_TEXT_COLUMN": "test_text",
|
||||
"GRAPHRAG_INPUT_TIMESTAMP_COLUMN": "test_timestamp",
|
||||
"GRAPHRAG_INPUT_TIMESTAMP_FORMAT": "test_format",
|
||||
"GRAPHRAG_INPUT_TITLE_COLUMN": "test_title",
|
||||
"GRAPHRAG_INPUT_FILE_TYPE": "text",
|
||||
"GRAPHRAG_LLM_CONCURRENT_REQUESTS": "12",
|
||||
"GRAPHRAG_LLM_DEPLOYMENT_NAME": "model-deployment-name-x",
|
||||
"GRAPHRAG_LLM_MAX_RETRIES": "312",
|
||||
"GRAPHRAG_LLM_MAX_RETRY_WAIT": "0.1122",
|
||||
"GRAPHRAG_LLM_MAX_TOKENS": "15000",
|
||||
"GRAPHRAG_LLM_MODEL_SUPPORTS_JSON": "true",
|
||||
"GRAPHRAG_LLM_MODEL": "test-llm",
|
||||
"GRAPHRAG_LLM_N": "1",
|
||||
"GRAPHRAG_LLM_REQUEST_TIMEOUT": "12.7",
|
||||
"GRAPHRAG_LLM_REQUESTS_PER_MINUTE": "900",
|
||||
"GRAPHRAG_LLM_SLEEP_ON_RATE_LIMIT_RECOMMENDATION": "False",
|
||||
"GRAPHRAG_LLM_THREAD_COUNT": "987",
|
||||
"GRAPHRAG_LLM_THREAD_STAGGER": "0.123",
|
||||
"GRAPHRAG_LLM_TOKENS_PER_MINUTE": "8000",
|
||||
"GRAPHRAG_LLM_TYPE": "azure_openai_chat",
|
||||
"GRAPHRAG_MAX_CLUSTER_SIZE": "123",
|
||||
"GRAPHRAG_NODE2VEC_ENABLED": "true",
|
||||
"GRAPHRAG_NODE2VEC_ITERATIONS": "878787",
|
||||
"GRAPHRAG_NODE2VEC_NUM_WALKS": "5000000",
|
||||
"GRAPHRAG_NODE2VEC_RANDOM_SEED": "010101",
|
||||
"GRAPHRAG_NODE2VEC_WALK_LENGTH": "555111",
|
||||
"GRAPHRAG_NODE2VEC_WINDOW_SIZE": "12345",
|
||||
"GRAPHRAG_REPORTING_STORAGE_ACCOUNT_BLOB_URL": "reporting_account_blob_url",
|
||||
"GRAPHRAG_REPORTING_BASE_DIR": "/some/reporting/dir",
|
||||
"GRAPHRAG_REPORTING_CONNECTION_STRING": "test_cs2",
|
||||
"GRAPHRAG_REPORTING_CONTAINER_NAME": "test_cn2",
|
||||
"GRAPHRAG_REPORTING_TYPE": "blob",
|
||||
"GRAPHRAG_SKIP_WORKFLOWS": "a,b,c",
|
||||
"GRAPHRAG_SNAPSHOT_GRAPHML": "true",
|
||||
"GRAPHRAG_SNAPSHOT_RAW_ENTITIES": "true",
|
||||
"GRAPHRAG_SNAPSHOT_TOP_LEVEL_NODES": "true",
|
||||
"GRAPHRAG_SNAPSHOT_EMBEDDINGS": "true",
|
||||
"GRAPHRAG_SNAPSHOT_TRANSIENT": "true",
|
||||
"GRAPHRAG_STORAGE_STORAGE_ACCOUNT_BLOB_URL": "storage_account_blob_url",
|
||||
"GRAPHRAG_STORAGE_BASE_DIR": "/some/storage/dir",
|
||||
"GRAPHRAG_STORAGE_CONNECTION_STRING": "test_cs",
|
||||
"GRAPHRAG_STORAGE_CONTAINER_NAME": "test_cn",
|
||||
"GRAPHRAG_STORAGE_TYPE": "blob",
|
||||
"GRAPHRAG_SUMMARIZE_DESCRIPTIONS_MAX_LENGTH": "12345",
|
||||
"GRAPHRAG_SUMMARIZE_DESCRIPTIONS_PROMPT_FILE": "tests/unit/config/prompt-d.txt",
|
||||
"GRAPHRAG_LLM_TEMPERATURE": "0.0",
|
||||
"GRAPHRAG_LLM_TOP_P": "1.0",
|
||||
"GRAPHRAG_UMAP_ENABLED": "true",
|
||||
"GRAPHRAG_LOCAL_SEARCH_TEXT_UNIT_PROP": "0.713",
|
||||
"GRAPHRAG_LOCAL_SEARCH_COMMUNITY_PROP": "0.1234",
|
||||
"GRAPHRAG_LOCAL_SEARCH_LLM_TEMPERATURE": "0.1",
|
||||
"GRAPHRAG_LOCAL_SEARCH_LLM_TOP_P": "0.9",
|
||||
"GRAPHRAG_LOCAL_SEARCH_LLM_N": "2",
|
||||
"GRAPHRAG_LOCAL_SEARCH_LLM_MAX_TOKENS": "12",
|
||||
"GRAPHRAG_LOCAL_SEARCH_TOP_K_RELATIONSHIPS": "15",
|
||||
"GRAPHRAG_LOCAL_SEARCH_TOP_K_ENTITIES": "14",
|
||||
"GRAPHRAG_LOCAL_SEARCH_CONVERSATION_HISTORY_MAX_TURNS": "2",
|
||||
"GRAPHRAG_LOCAL_SEARCH_MAX_TOKENS": "142435",
|
||||
"GRAPHRAG_GLOBAL_SEARCH_LLM_TEMPERATURE": "0.1",
|
||||
"GRAPHRAG_GLOBAL_SEARCH_LLM_TOP_P": "0.9",
|
||||
"GRAPHRAG_GLOBAL_SEARCH_LLM_N": "2",
|
||||
"GRAPHRAG_GLOBAL_SEARCH_MAX_TOKENS": "5123",
|
||||
"GRAPHRAG_GLOBAL_SEARCH_DATA_MAX_TOKENS": "123",
|
||||
"GRAPHRAG_GLOBAL_SEARCH_MAP_MAX_TOKENS": "4123",
|
||||
"GRAPHRAG_GLOBAL_SEARCH_CONCURRENCY": "7",
|
||||
"GRAPHRAG_GLOBAL_SEARCH_REDUCE_MAX_TOKENS": "15432",
|
||||
}
|
||||
|
||||
|
||||
class TestDefaultConfig(unittest.TestCase):
|
||||
def test_clear_warnings(self):
|
||||
"""Just clearing unused import warnings"""
|
||||
assert CacheConfig is not None
|
||||
assert ChunkingConfig is not None
|
||||
assert ClaimExtractionConfig is not None
|
||||
assert ClusterGraphConfig is not None
|
||||
assert CommunityReportsConfig is not None
|
||||
assert DRIFTSearchConfig is not None
|
||||
assert EmbedGraphConfig is not None
|
||||
assert EntityExtractionConfig is not None
|
||||
assert GlobalSearchConfig is not None
|
||||
assert GraphRagConfig is not None
|
||||
assert InputConfig is not None
|
||||
assert LLMParameters is not None
|
||||
assert LocalSearchConfig is not None
|
||||
assert BasicSearchConfig is not None
|
||||
assert ParallelizationParameters is not None
|
||||
assert ReportingConfig is not None
|
||||
assert SnapshotsConfig is not None
|
||||
assert StorageConfig is not None
|
||||
assert SummarizeDescriptionsConfig is not None
|
||||
assert TextEmbeddingConfig is not None
|
||||
assert UmapConfig is not None
|
||||
assert PipelineFileReportingConfig is not None
|
||||
assert PipelineFileStorageConfig is not None
|
||||
assert PipelineInputConfig is not None
|
||||
assert PipelineFileCacheConfig is not None
|
||||
|
||||
@mock.patch.dict(os.environ, {"OPENAI_API_KEY": "test"}, clear=True)
|
||||
def test_string_repr(self):
|
||||
# __str__ can be json loaded
|
||||
config = create_graphrag_config()
|
||||
string_repr = str(config)
|
||||
assert string_repr is not None
|
||||
assert json.loads(string_repr) is not None
|
||||
|
||||
# __repr__ can be eval()'d
|
||||
repr_str = config.__repr__()
|
||||
# TODO: add __repr__ to enum
|
||||
repr_str = repr_str.replace("async_mode=<AsyncType.Threaded: 'threaded'>,", "")
|
||||
assert eval(repr_str) is not None
|
||||
|
||||
@mock.patch.dict(os.environ, {}, clear=True)
|
||||
def test_default_config_with_no_env_vars_throws(self):
|
||||
with pytest.raises(ApiKeyMissingError):
|
||||
# This should throw an error because the API key is missing
|
||||
create_graphrag_config()
|
||||
|
||||
@mock.patch.dict(os.environ, {"GRAPHRAG_API_KEY": "test"}, clear=True)
|
||||
def test_default_config_with_api_key_passes(self):
|
||||
# doesn't throw
|
||||
config = create_graphrag_config()
|
||||
assert config is not None
|
||||
|
||||
@mock.patch.dict(os.environ, {"OPENAI_API_KEY": "test"}, clear=True)
|
||||
def test_default_config_with_oai_key_passes_envvar(self):
|
||||
# doesn't throw
|
||||
config = create_graphrag_config()
|
||||
assert config is not None
|
||||
|
||||
def test_default_config_with_oai_key_passes_obj(self):
|
||||
# doesn't throw
|
||||
config = create_graphrag_config({"llm": {"api_key": "test"}})
|
||||
assert config is not None
|
||||
|
||||
@mock.patch.dict(
|
||||
os.environ,
|
||||
{"GRAPHRAG_API_KEY": "test", "GRAPHRAG_LLM_TYPE": "azure_openai_chat"},
|
||||
clear=True,
|
||||
)
|
||||
def test_throws_if_azure_is_used_without_api_base_envvar(self):
|
||||
with pytest.raises(AzureApiBaseMissingError):
|
||||
create_graphrag_config()
|
||||
|
||||
@mock.patch.dict(
|
||||
os.environ,
|
||||
{
|
||||
"GRAPHRAG_API_KEY": "test",
|
||||
"GRAPHRAG_LLM_TYPE": "azure_openai_chat",
|
||||
"GRAPHRAG_API_BASE": "http://some/base",
|
||||
},
|
||||
clear=True,
|
||||
)
|
||||
def test_throws_if_azure_is_used_without_llm_deployment_name_envvar(self):
|
||||
with pytest.raises(AzureDeploymentNameMissingError):
|
||||
create_graphrag_config()
|
||||
|
||||
@mock.patch.dict(
|
||||
os.environ,
|
||||
{
|
||||
"GRAPHRAG_API_KEY": "test",
|
||||
"GRAPHRAG_EMBEDDING_TYPE": "azure_openai_embedding",
|
||||
"GRAPHRAG_EMBEDDING_DEPLOYMENT_NAME": "x",
|
||||
},
|
||||
clear=True,
|
||||
)
|
||||
def test_throws_if_azure_is_used_without_embedding_api_base_envvar(self):
|
||||
with pytest.raises(AzureApiBaseMissingError):
|
||||
create_graphrag_config()
|
||||
|
||||
@mock.patch.dict(
|
||||
os.environ,
|
||||
{
|
||||
"GRAPHRAG_API_KEY": "test",
|
||||
"GRAPHRAG_API_BASE": "http://some/base",
|
||||
"GRAPHRAG_LLM_DEPLOYMENT_NAME": "x",
|
||||
"GRAPHRAG_LLM_TYPE": "azure_openai_chat",
|
||||
"GRAPHRAG_EMBEDDING_TYPE": "azure_openai_embedding",
|
||||
},
|
||||
clear=True,
|
||||
)
|
||||
def test_throws_if_azure_is_used_without_embedding_deployment_name_envvar(self):
|
||||
with pytest.raises(AzureDeploymentNameMissingError):
|
||||
create_graphrag_config()
|
||||
|
||||
@mock.patch.dict(
|
||||
os.environ,
|
||||
{
|
||||
"GRAPHRAG_API_KEY": "test",
|
||||
"GRAPHRAG_LLM_TYPE": "azure_openai_chat",
|
||||
"GRAPHRAG_LLM_DEPLOYMENT_NAME": "x",
|
||||
},
|
||||
clear=True,
|
||||
)
|
||||
def test_throws_if_azure_is_used_without_api_base(self):
|
||||
with pytest.raises(AzureApiBaseMissingError):
|
||||
create_graphrag_config()
|
||||
|
||||
@mock.patch.dict(
|
||||
os.environ,
|
||||
{
|
||||
"GRAPHRAG_API_KEY": "test",
|
||||
"GRAPHRAG_LLM_TYPE": "azure_openai_chat",
|
||||
"GRAPHRAG_LLM_API_BASE": "http://some/base",
|
||||
},
|
||||
clear=True,
|
||||
)
|
||||
def test_throws_if_azure_is_used_without_llm_deployment_name(self):
|
||||
with pytest.raises(AzureDeploymentNameMissingError):
|
||||
create_graphrag_config()
|
||||
|
||||
@mock.patch.dict(
|
||||
os.environ,
|
||||
{
|
||||
"GRAPHRAG_API_KEY": "test",
|
||||
"GRAPHRAG_LLM_TYPE": "azure_openai_chat",
|
||||
"GRAPHRAG_API_BASE": "http://some/base",
|
||||
"GRAPHRAG_LLM_DEPLOYMENT_NAME": "model-deployment-name-x",
|
||||
"GRAPHRAG_EMBEDDING_TYPE": "azure_openai_embedding",
|
||||
},
|
||||
clear=True,
|
||||
)
|
||||
def test_throws_if_azure_is_used_without_embedding_deployment_name(self):
|
||||
with pytest.raises(AzureDeploymentNameMissingError):
|
||||
create_graphrag_config()
|
||||
|
||||
@mock.patch.dict(
|
||||
os.environ,
|
||||
{
|
||||
"GRAPHRAG_LLM_API_KEY": "test",
|
||||
"GRAPHRAG_ENTITY_EXTRACTION_MAX_GLEANINGS": "0",
|
||||
"GRAPHRAG_CLAIM_EXTRACTION_MAX_GLEANINGS": "0",
|
||||
},
|
||||
clear=True,
|
||||
)
|
||||
def test_can_set_gleanings_to_zero(self):
|
||||
parameters = create_graphrag_config()
|
||||
assert parameters.claim_extraction.max_gleanings == 0
|
||||
assert parameters.entity_extraction.max_gleanings == 0
|
||||
|
||||
@mock.patch.dict(
|
||||
os.environ,
|
||||
{"GRAPHRAG_LLM_API_KEY": "test", "GRAPHRAG_CHUNK_BY_COLUMNS": ""},
|
||||
clear=True,
|
||||
)
|
||||
def test_can_set_no_chunk_by_columns(self):
|
||||
parameters = create_graphrag_config()
|
||||
assert parameters.chunks.group_by_columns == []
|
||||
|
||||
def test_all_env_vars_is_accurate(self):
|
||||
env_var_docs_path = Path("docs/config/env_vars.md")
|
||||
|
||||
env_var_docs = env_var_docs_path.read_text(encoding="utf-8")
|
||||
|
||||
def find_envvar_names(text) -> set[str]:
|
||||
pattern = r"`(GRAPHRAG_[^`]+)`"
|
||||
found = re.findall(pattern, text)
|
||||
found = {f for f in found if not f.endswith("_")}
|
||||
return {*found}
|
||||
|
||||
graphrag_strings = find_envvar_names(env_var_docs)
|
||||
|
||||
missing = {s for s in graphrag_strings if s not in ALL_ENV_VARS} - {
|
||||
# Remove configs covered by the base LLM connection configs
|
||||
"GRAPHRAG_LLM_API_KEY",
|
||||
"GRAPHRAG_LLM_API_BASE",
|
||||
"GRAPHRAG_LLM_API_VERSION",
|
||||
"GRAPHRAG_LLM_API_ORGANIZATION",
|
||||
"GRAPHRAG_LLM_API_PROXY",
|
||||
"GRAPHRAG_EMBEDDING_API_KEY",
|
||||
"GRAPHRAG_EMBEDDING_API_BASE",
|
||||
"GRAPHRAG_EMBEDDING_API_VERSION",
|
||||
"GRAPHRAG_EMBEDDING_API_ORGANIZATION",
|
||||
"GRAPHRAG_EMBEDDING_API_PROXY",
|
||||
}
|
||||
if missing:
|
||||
msg = f"{len(missing)} missing env vars: {missing}"
|
||||
print(msg)
|
||||
raise ValueError(msg)
|
||||
|
||||
@mock.patch.dict(
|
||||
os.environ,
|
||||
{"GRAPHRAG_API_KEY": "test"},
|
||||
clear=True,
|
||||
)
|
||||
@mock.patch.dict(
|
||||
os.environ,
|
||||
ALL_ENV_VARS,
|
||||
clear=True,
|
||||
)
|
||||
@mock.patch.dict(
|
||||
os.environ,
|
||||
{"GRAPHRAG_API_KEY": "test"},
|
||||
clear=True,
|
||||
)
|
||||
def test_default_values(self) -> None:
|
||||
parameters = create_graphrag_config()
|
||||
assert parameters.async_mode == defs.ASYNC_MODE
|
||||
assert parameters.cache.base_dir == defs.CACHE_BASE_DIR
|
||||
assert parameters.cache.type == defs.CACHE_TYPE
|
||||
assert parameters.cache.base_dir == defs.CACHE_BASE_DIR
|
||||
assert parameters.chunks.group_by_columns == defs.CHUNK_GROUP_BY_COLUMNS
|
||||
assert parameters.chunks.overlap == defs.CHUNK_OVERLAP
|
||||
assert parameters.chunks.size == defs.CHUNK_SIZE
|
||||
assert parameters.claim_extraction.description == defs.CLAIM_DESCRIPTION
|
||||
assert parameters.claim_extraction.max_gleanings == defs.CLAIM_MAX_GLEANINGS
|
||||
assert (
|
||||
parameters.community_reports.max_input_length
|
||||
== defs.COMMUNITY_REPORT_MAX_INPUT_LENGTH
|
||||
)
|
||||
assert (
|
||||
parameters.community_reports.max_length == defs.COMMUNITY_REPORT_MAX_LENGTH
|
||||
)
|
||||
assert parameters.embeddings.batch_max_tokens == defs.EMBEDDING_BATCH_MAX_TOKENS
|
||||
assert parameters.embeddings.batch_size == defs.EMBEDDING_BATCH_SIZE
|
||||
assert parameters.embeddings.llm.model == defs.EMBEDDING_MODEL
|
||||
assert parameters.embeddings.target == defs.EMBEDDING_TARGET
|
||||
assert parameters.embeddings.llm.type == defs.EMBEDDING_TYPE
|
||||
assert (
|
||||
parameters.embeddings.llm.requests_per_minute
|
||||
== defs.LLM_REQUESTS_PER_MINUTE
|
||||
)
|
||||
assert parameters.embeddings.llm.tokens_per_minute == defs.LLM_TOKENS_PER_MINUTE
|
||||
assert (
|
||||
parameters.embeddings.llm.sleep_on_rate_limit_recommendation
|
||||
== defs.LLM_SLEEP_ON_RATE_LIMIT_RECOMMENDATION
|
||||
)
|
||||
assert (
|
||||
parameters.entity_extraction.entity_types
|
||||
== defs.ENTITY_EXTRACTION_ENTITY_TYPES
|
||||
)
|
||||
assert (
|
||||
parameters.entity_extraction.max_gleanings
|
||||
== defs.ENTITY_EXTRACTION_MAX_GLEANINGS
|
||||
)
|
||||
assert parameters.encoding_model == defs.ENCODING_MODEL
|
||||
assert parameters.input.base_dir == defs.INPUT_BASE_DIR
|
||||
assert parameters.input.file_pattern == defs.INPUT_CSV_PATTERN
|
||||
assert parameters.input.encoding == defs.INPUT_FILE_ENCODING
|
||||
assert parameters.input.type == defs.INPUT_TYPE
|
||||
assert parameters.input.base_dir == defs.INPUT_BASE_DIR
|
||||
assert parameters.input.text_column == defs.INPUT_TEXT_COLUMN
|
||||
assert parameters.input.file_type == defs.INPUT_FILE_TYPE
|
||||
assert parameters.llm.concurrent_requests == defs.LLM_CONCURRENT_REQUESTS
|
||||
assert parameters.llm.max_retries == defs.LLM_MAX_RETRIES
|
||||
assert parameters.llm.max_retry_wait == defs.LLM_MAX_RETRY_WAIT
|
||||
assert parameters.llm.max_tokens == defs.LLM_MAX_TOKENS
|
||||
assert parameters.llm.model == defs.LLM_MODEL
|
||||
assert parameters.llm.request_timeout == defs.LLM_REQUEST_TIMEOUT
|
||||
assert parameters.llm.requests_per_minute == defs.LLM_REQUESTS_PER_MINUTE
|
||||
assert parameters.llm.tokens_per_minute == defs.LLM_TOKENS_PER_MINUTE
|
||||
assert (
|
||||
parameters.llm.sleep_on_rate_limit_recommendation
|
||||
== defs.LLM_SLEEP_ON_RATE_LIMIT_RECOMMENDATION
|
||||
)
|
||||
assert parameters.llm.type == defs.LLM_TYPE
|
||||
assert parameters.cluster_graph.max_cluster_size == defs.MAX_CLUSTER_SIZE
|
||||
assert parameters.embed_graph.enabled == defs.NODE2VEC_ENABLED
|
||||
assert parameters.embed_graph.iterations == defs.NODE2VEC_ITERATIONS
|
||||
assert parameters.embed_graph.num_walks == defs.NODE2VEC_NUM_WALKS
|
||||
assert parameters.embed_graph.random_seed == defs.NODE2VEC_RANDOM_SEED
|
||||
assert parameters.embed_graph.walk_length == defs.NODE2VEC_WALK_LENGTH
|
||||
assert parameters.embed_graph.window_size == defs.NODE2VEC_WINDOW_SIZE
|
||||
assert (
|
||||
parameters.parallelization.num_threads == defs.PARALLELIZATION_NUM_THREADS
|
||||
)
|
||||
assert parameters.parallelization.stagger == defs.PARALLELIZATION_STAGGER
|
||||
assert parameters.reporting.type == defs.REPORTING_TYPE
|
||||
assert parameters.reporting.base_dir == defs.REPORTING_BASE_DIR
|
||||
assert parameters.snapshots.graphml == defs.SNAPSHOTS_GRAPHML
|
||||
assert parameters.snapshots.embeddings == defs.SNAPSHOTS_EMBEDDINGS
|
||||
assert parameters.snapshots.transient == defs.SNAPSHOTS_TRANSIENT
|
||||
assert parameters.storage.base_dir == defs.STORAGE_BASE_DIR
|
||||
assert parameters.storage.type == defs.STORAGE_TYPE
|
||||
assert parameters.umap.enabled == defs.UMAP_ENABLED
|
||||
|
||||
@mock.patch.dict(
|
||||
os.environ,
|
||||
{"GRAPHRAG_API_KEY": "test"},
|
||||
clear=True,
|
||||
)
|
||||
def test_prompt_file_reading(self):
|
||||
config = create_graphrag_config({
|
||||
"entity_extraction": {"prompt": "tests/unit/config/prompt-a.txt"},
|
||||
"claim_extraction": {"prompt": "tests/unit/config/prompt-b.txt"},
|
||||
"community_reports": {"prompt": "tests/unit/config/prompt-c.txt"},
|
||||
"summarize_descriptions": {"prompt": "tests/unit/config/prompt-d.txt"},
|
||||
})
|
||||
strategy = config.entity_extraction.resolved_strategy(".", "abc123")
|
||||
assert strategy["extraction_prompt"] == "Hello, World! A"
|
||||
assert strategy["encoding_name"] == "abc123"
|
||||
|
||||
strategy = config.claim_extraction.resolved_strategy(".", "encoding_b")
|
||||
assert strategy["extraction_prompt"] == "Hello, World! B"
|
||||
|
||||
strategy = config.community_reports.resolved_strategy(".")
|
||||
assert strategy["extraction_prompt"] == "Hello, World! C"
|
||||
|
||||
strategy = config.summarize_descriptions.resolved_strategy(".")
|
||||
assert strategy["summarize_prompt"] == "Hello, World! D"
|
||||
|
||||
|
||||
@mock.patch.dict(
|
||||
os.environ,
|
||||
{
|
||||
"PIPELINE_LLM_API_KEY": "test",
|
||||
"PIPELINE_LLM_API_BASE": "http://test",
|
||||
"PIPELINE_LLM_API_VERSION": "v1",
|
||||
"PIPELINE_LLM_MODEL": "test-llm",
|
||||
"PIPELINE_LLM_DEPLOYMENT_NAME": "test",
|
||||
},
|
||||
clear=True,
|
||||
)
|
||||
def test_yaml_load_e2e():
|
||||
config_dict = yaml.safe_load(
|
||||
"""
|
||||
input:
|
||||
file_type: text
|
||||
|
||||
llm:
|
||||
type: azure_openai_chat
|
||||
api_key: ${PIPELINE_LLM_API_KEY}
|
||||
api_base: ${PIPELINE_LLM_API_BASE}
|
||||
api_version: ${PIPELINE_LLM_API_VERSION}
|
||||
model: ${PIPELINE_LLM_MODEL}
|
||||
deployment_name: ${PIPELINE_LLM_DEPLOYMENT_NAME}
|
||||
model_supports_json: True
|
||||
tokens_per_minute: 80000
|
||||
requests_per_minute: 900
|
||||
thread_count: 50
|
||||
concurrent_requests: 25
|
||||
"""
|
||||
)
|
||||
# create default configuration pipeline parameters from the custom settings
|
||||
model = config_dict
|
||||
parameters = create_graphrag_config(model, ".")
|
||||
|
||||
assert parameters.llm.api_key == "test"
|
||||
assert parameters.llm.model == "test-llm"
|
||||
assert parameters.llm.api_base == "http://test"
|
||||
assert parameters.llm.api_version == "v1"
|
||||
assert parameters.llm.deployment_name == "test"
|
||||
|
||||
# generate the pipeline from the default parameters
|
||||
pipeline_config = create_pipeline_config(parameters, True)
|
||||
|
||||
config_str = pipeline_config.model_dump_json()
|
||||
assert "${PIPELINE_LLM_API_KEY}" not in config_str
|
||||
assert "${PIPELINE_LLM_API_BASE}" not in config_str
|
||||
assert "${PIPELINE_LLM_API_VERSION}" not in config_str
|
||||
assert "${PIPELINE_LLM_MODEL}" not in config_str
|
||||
assert "${PIPELINE_LLM_DEPLOYMENT_NAME}" not in config_str
|
||||
@ -1,42 +0,0 @@
|
||||
# Copyright (c) 2024 Microsoft Corporation.
|
||||
# Licensed under the MIT License
|
||||
|
||||
from pathlib import Path
|
||||
|
||||
from graphrag.config.resolve_path import resolve_path
|
||||
|
||||
|
||||
def test_resolve_path_no_timestamp_with_run_id():
|
||||
path = Path("path/to/data")
|
||||
result = resolve_path(path, pattern_or_timestamp_value="20240812-121000")
|
||||
assert result == path
|
||||
|
||||
|
||||
def test_resolve_path_no_timestamp_without_run_id():
|
||||
path = Path("path/to/data")
|
||||
result = resolve_path(path)
|
||||
assert result == path
|
||||
|
||||
|
||||
def test_resolve_path_with_timestamp_and_run_id():
|
||||
path = Path("some/path/${timestamp}/data")
|
||||
expected = Path("some/path/20240812/data")
|
||||
result = resolve_path(path, pattern_or_timestamp_value="20240812")
|
||||
assert result == expected
|
||||
|
||||
|
||||
def test_resolve_path_with_timestamp_and_inferred_directory():
|
||||
cwd = Path(__file__).parent
|
||||
path = cwd / "fixtures/timestamp_dirs/${timestamp}/data"
|
||||
expected = cwd / "fixtures/timestamp_dirs/20240812-120000/data"
|
||||
result = resolve_path(path)
|
||||
assert result == expected
|
||||
|
||||
|
||||
def test_resolve_path_absolute():
|
||||
cwd = Path(__file__).parent
|
||||
path = "fixtures/timestamp_dirs/${timestamp}/data"
|
||||
expected = cwd / "fixtures/timestamp_dirs/20240812-120000/data"
|
||||
result = resolve_path(path, cwd)
|
||||
assert result == expected
|
||||
assert result.is_absolute()
|
||||
569
tests/unit/config/utils.py
Normal file
@ -0,0 +1,569 @@
|
||||
# Copyright (c) 2024 Microsoft Corporation.
|
||||
# Licensed under the MIT License
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
import graphrag.config.defaults as defs
|
||||
from graphrag.config.models.basic_search_config import BasicSearchConfig
|
||||
from graphrag.config.models.cache_config import CacheConfig
|
||||
from graphrag.config.models.chunking_config import ChunkingConfig
|
||||
from graphrag.config.models.claim_extraction_config import ClaimExtractionConfig
|
||||
from graphrag.config.models.cluster_graph_config import ClusterGraphConfig
|
||||
from graphrag.config.models.community_reports_config import CommunityReportsConfig
|
||||
from graphrag.config.models.drift_search_config import DRIFTSearchConfig
|
||||
from graphrag.config.models.embed_graph_config import EmbedGraphConfig
|
||||
from graphrag.config.models.entity_extraction_config import EntityExtractionConfig
|
||||
from graphrag.config.models.global_search_config import GlobalSearchConfig
|
||||
from graphrag.config.models.graph_rag_config import GraphRagConfig
|
||||
from graphrag.config.models.input_config import InputConfig
|
||||
from graphrag.config.models.language_model_config import LanguageModelConfig
|
||||
from graphrag.config.models.local_search_config import LocalSearchConfig
|
||||
from graphrag.config.models.output_config import OutputConfig
|
||||
from graphrag.config.models.reporting_config import ReportingConfig
|
||||
from graphrag.config.models.snapshots_config import SnapshotsConfig
|
||||
from graphrag.config.models.summarize_descriptions_config import (
|
||||
SummarizeDescriptionsConfig,
|
||||
)
|
||||
from graphrag.config.models.text_embedding_config import TextEmbeddingConfig
|
||||
from graphrag.config.models.umap_config import UmapConfig
|
||||
from graphrag.config.models.vector_store_config import VectorStoreConfig
|
||||
|
||||
FAKE_API_KEY = "NOT_AN_API_KEY"
|
||||
|
||||
DEFAULT_CHAT_MODEL_CONFIG = {
|
||||
"api_key": FAKE_API_KEY,
|
||||
"type": defs.LLM_TYPE.value,
|
||||
"model": defs.LLM_MODEL,
|
||||
}
|
||||
|
||||
DEFAULT_EMBEDDING_MODEL_CONFIG = {
|
||||
"api_key": FAKE_API_KEY,
|
||||
"type": defs.EMBEDDING_TYPE.value,
|
||||
"model": defs.EMBEDDING_MODEL,
|
||||
}
|
||||
|
||||
DEFAULT_MODEL_CONFIG = {
|
||||
defs.DEFAULT_CHAT_MODEL_ID: DEFAULT_CHAT_MODEL_CONFIG,
|
||||
defs.DEFAULT_EMBEDDING_MODEL_ID: DEFAULT_EMBEDDING_MODEL_CONFIG,
|
||||
}
|
||||
|
||||
DEFAULT_GRAPHRAG_CONFIG_SETTINGS = {
|
||||
"models": DEFAULT_MODEL_CONFIG,
|
||||
"vector_store": {
|
||||
"type": defs.VECTOR_STORE_TYPE,
|
||||
"db_uri": defs.VECTOR_STORE_DB_URI,
|
||||
"container_name": defs.VECTOR_STORE_CONTAINER_NAME,
|
||||
"overwrite": defs.VECTOR_STORE_OVERWRITE,
|
||||
"url": None,
|
||||
"api_key": None,
|
||||
"audience": None,
|
||||
},
|
||||
"reporting": {
|
||||
"type": defs.REPORTING_TYPE,
|
||||
"base_dir": defs.REPORTING_BASE_DIR,
|
||||
"connection_string": None,
|
||||
"container_name": None,
|
||||
"storage_account_blob_url": None,
|
||||
},
|
||||
"output": {
|
||||
"type": defs.OUTPUT_TYPE,
|
||||
"base_dir": defs.OUTPUT_BASE_DIR,
|
||||
"connection_string": None,
|
||||
"container_name": None,
|
||||
"storage_account_blob_url": None,
|
||||
},
|
||||
"update_index_output": None,
|
||||
"cache": {
|
||||
"type": defs.CACHE_TYPE,
|
||||
"base_dir": defs.CACHE_BASE_DIR,
|
||||
"connection_string": None,
|
||||
"container_name": None,
|
||||
"storage_account_blob_url": None,
|
||||
"cosmosdb_account_url": None,
|
||||
},
|
||||
"input": {
|
||||
"type": defs.INPUT_TYPE,
|
||||
"file_type": defs.INPUT_FILE_TYPE,
|
||||
"base_dir": defs.INPUT_BASE_DIR,
|
||||
"connection_string": None,
|
||||
"storage_account_blob_url": None,
|
||||
"container_name": None,
|
||||
"encoding": defs.INPUT_FILE_ENCODING,
|
||||
"file_pattern": defs.INPUT_TEXT_PATTERN,
|
||||
"file_filter": None,
|
||||
"source_column": None,
|
||||
"timestamp_column": None,
|
||||
"timestamp_format": None,
|
||||
"text_column": defs.INPUT_TEXT_COLUMN,
|
||||
"title_column": None,
|
||||
"document_attribute_columns": [],
|
||||
},
|
||||
"embed_graph": {
|
||||
"enabled": defs.NODE2VEC_ENABLED,
|
||||
"dimensions": defs.NODE2VEC_DIMENSIONS,
|
||||
"num_walks": defs.NODE2VEC_NUM_WALKS,
|
||||
"walk_length": defs.NODE2VEC_WALK_LENGTH,
|
||||
"window_size": defs.NODE2VEC_WINDOW_SIZE,
|
||||
"iterations": defs.NODE2VEC_ITERATIONS,
|
||||
"random_seed": defs.NODE2VEC_RANDOM_SEED,
|
||||
"use_lcc": defs.USE_LCC,
|
||||
},
|
||||
"embeddings": {
|
||||
"batch_size": defs.EMBEDDING_BATCH_SIZE,
|
||||
"batch_max_tokens": defs.EMBEDDING_BATCH_MAX_TOKENS,
|
||||
"target": defs.EMBEDDING_TARGET,
|
||||
"strategy": None,
|
||||
"model_id": defs.EMBEDDING_MODEL_ID,
|
||||
},
|
||||
"chunks": {
|
||||
"size": defs.CHUNK_SIZE,
|
||||
"overlap": defs.CHUNK_OVERLAP,
|
||||
"group_by_columns": defs.CHUNK_GROUP_BY_COLUMNS,
|
||||
"strategy": defs.CHUNK_STRATEGY,
|
||||
"encoding_model": defs.ENCODING_MODEL,
|
||||
},
|
||||
"snapshots": {
|
||||
"embeddings": defs.SNAPSHOTS_EMBEDDINGS,
|
||||
"graphml": defs.SNAPSHOTS_GRAPHML,
|
||||
"transient": defs.SNAPSHOTS_TRANSIENT,
|
||||
},
|
||||
"entity_extraction": {
|
||||
"prompt": None,
|
||||
"entity_types": defs.ENTITY_EXTRACTION_ENTITY_TYPES,
|
||||
"max_gleanings": defs.ENTITY_EXTRACTION_MAX_GLEANINGS,
|
||||
"strategy": None,
|
||||
"encoding_model": None,
|
||||
"model_id": defs.ENTITY_EXTRACTION_MODEL_ID,
|
||||
},
|
||||
"summarize_descriptions": {
|
||||
"prompt": None,
|
||||
"max_length": defs.SUMMARIZE_DESCRIPTIONS_MAX_LENGTH,
|
||||
"strategy": None,
|
||||
"model_id": defs.SUMMARIZE_MODEL_ID,
|
||||
},
|
||||
"community_report": {
|
||||
"prompt": None,
|
||||
"max_length": defs.COMMUNITY_REPORT_MAX_LENGTH,
|
||||
"max_input_length": defs.COMMUNITY_REPORT_MAX_INPUT_LENGTH,
|
||||
"strategy": None,
|
||||
"model_id": defs.COMMUNITY_REPORT_MODEL_ID,
|
||||
},
|
||||
"claim_extaction": {
|
||||
"enabled": defs.CLAIM_EXTRACTION_ENABLED,
|
||||
"prompt": None,
|
||||
"description": defs.CLAIM_DESCRIPTION,
|
||||
"max_gleanings": defs.CLAIM_MAX_GLEANINGS,
|
||||
"strategy": None,
|
||||
"encoding_model": None,
|
||||
"model_id": defs.CLAIM_EXTRACTION_MODEL_ID,
|
||||
},
|
||||
"cluster_graph": {
|
||||
"max_cluster_size": defs.MAX_CLUSTER_SIZE,
|
||||
"use_lcc": defs.USE_LCC,
|
||||
"seed": defs.CLUSTER_GRAPH_SEED,
|
||||
},
|
||||
"umap": {"enabled": defs.UMAP_ENABLED},
|
||||
"local_search": {
|
||||
"prompt": None,
|
||||
"text_unit_prop": defs.LOCAL_SEARCH_TEXT_UNIT_PROP,
|
||||
"community_prop": defs.LOCAL_SEARCH_COMMUNITY_PROP,
|
||||
"conversation_history_max_turns": defs.LOCAL_SEARCH_CONVERSATION_HISTORY_MAX_TURNS,
|
||||
"top_k_entities": defs.LOCAL_SEARCH_TOP_K_MAPPED_ENTITIES,
|
||||
"top_k_relationships": defs.LOCAL_SEARCH_TOP_K_RELATIONSHIPS,
|
||||
"temperature": defs.LOCAL_SEARCH_LLM_TEMPERATURE,
|
||||
"top_p": defs.LOCAL_SEARCH_LLM_TOP_P,
|
||||
"n": defs.LOCAL_SEARCH_LLM_N,
|
||||
"max_tokens": defs.LOCAL_SEARCH_MAX_TOKENS,
|
||||
"llm_max_tokens": defs.LOCAL_SEARCH_LLM_MAX_TOKENS,
|
||||
},
|
||||
"global_search": {
|
||||
"map_prompt": None,
|
||||
"reduce_prompt": None,
|
||||
"knowledge_prompt": None,
|
||||
"temperature": defs.GLOBAL_SEARCH_LLM_TEMPERATURE,
|
||||
"top_p": defs.GLOBAL_SEARCH_LLM_TOP_P,
|
||||
"n": defs.GLOBAL_SEARCH_LLM_N,
|
||||
"max_tokens": defs.GLOBAL_SEARCH_MAX_TOKENS,
|
||||
"data_max_tokens": defs.GLOBAL_SEARCH_DATA_MAX_TOKENS,
|
||||
"map_max_tokens": defs.GLOBAL_SEARCH_MAP_MAX_TOKENS,
|
||||
"reduce_max_tokens": defs.GLOBAL_SEARCH_REDUCE_MAX_TOKENS,
|
||||
"concurrency": defs.GLOBAL_SEARCH_CONCURRENCY,
|
||||
"dynamic_search_llm": defs.DYNAMIC_SEARCH_LLM_MODEL,
|
||||
"dynamic_search_threshold": defs.DYNAMIC_SEARCH_RATE_THRESHOLD,
|
||||
"dynamic_search_keep_parent": defs.DYNAMIC_SEARCH_KEEP_PARENT,
|
||||
"dynamic_search_num_repeats": defs.DYNAMIC_SEARCH_NUM_REPEATS,
|
||||
"dynamic_search_use_summary": defs.DYNAMIC_SEARCH_USE_SUMMARY,
|
||||
"dynamic_search_concurrent_coroutines": defs.DYNAMIC_SEARCH_CONCURRENT_COROUTINES,
|
||||
"dynamic_search_max_level": defs.DYNAMIC_SEARCH_MAX_LEVEL,
|
||||
},
|
||||
"drift_search": {
|
||||
"prompt": None,
|
||||
"temperature": defs.DRIFT_SEARCH_LLM_TEMPERATURE,
|
||||
"top_p": defs.DRIFT_SEARCH_LLM_TOP_P,
|
||||
"n": defs.DRIFT_SEARCH_LLM_N,
|
||||
"max_tokens": defs.DRIFT_SEARCH_MAX_TOKENS,
|
||||
"data_max_tokens": defs.DRIFT_SEARCH_DATA_MAX_TOKENS,
|
||||
"concurrency": defs.DRIFT_SEARCH_CONCURRENCY,
|
||||
"drift_k_followups": defs.DRIFT_SEARCH_K_FOLLOW_UPS,
|
||||
"primer_folds": defs.DRIFT_SEARCH_PRIMER_FOLDS,
|
||||
"primer_llm_max_tokens": defs.DRIFT_SEARCH_PRIMER_MAX_TOKENS,
|
||||
"n_depth": defs.DRIFT_N_DEPTH,
|
||||
"local_search_text_unit_prop": defs.DRIFT_LOCAL_SEARCH_TEXT_UNIT_PROP,
|
||||
"local_search_community_prop": defs.DRIFT_LOCAL_SEARCH_COMMUNITY_PROP,
|
||||
"local_search_top_k_mapped_entities": defs.DRIFT_LOCAL_SEARCH_TOP_K_MAPPED_ENTITIES,
|
||||
"local_search_top_k_relationships": defs.DRIFT_LOCAL_SEARCH_TOP_K_RELATIONSHIPS,
|
||||
"local_search_max_data_tokens": defs.DRIFT_LOCAL_SEARCH_MAX_TOKENS,
|
||||
"local_search_temperature": defs.DRIFT_LOCAL_SEARCH_LLM_TEMPERATURE,
|
||||
"local_search_top_p": defs.DRIFT_LOCAL_SEARCH_LLM_TOP_P,
|
||||
"local_search_n": defs.DRIFT_LOCAL_SEARCH_LLM_N,
|
||||
"local_search_max_tokens": defs.DRIFT_LOCAL_SEARCH_MAX_TOKENS,
|
||||
},
|
||||
"basic_search": {
|
||||
"prompt": None,
|
||||
"text_unit_prop": defs.BASIC_SEARCH_TEXT_UNIT_PROP,
|
||||
"conversation_history_max_turns": defs.BASIC_SEARCH_CONVERSATION_HISTORY_MAX_TURNS,
|
||||
"temperature": defs.BASIC_SEARCH_LLM_TEMPERATURE,
|
||||
"top_p": defs.BASIC_SEARCH_LLM_TOP_P,
|
||||
"n": defs.BASIC_SEARCH_LLM_N,
|
||||
"max_tokens": defs.BASIC_SEARCH_MAX_TOKENS,
|
||||
"llm_max_tokens": defs.BASIC_SEARCH_LLM_MAX_TOKENS,
|
||||
},
|
||||
}
|
||||
|
||||
|
||||
def get_default_graphrag_config(root_dir: str | None = None) -> GraphRagConfig:
|
||||
if root_dir is not None and root_dir.strip() != "":
|
||||
DEFAULT_GRAPHRAG_CONFIG_SETTINGS["root_dir"] = root_dir
|
||||
|
||||
return GraphRagConfig(**DEFAULT_GRAPHRAG_CONFIG_SETTINGS)
|
||||
|
||||
|
||||
def assert_language_model_configs(
    actual: LanguageModelConfig, expected: LanguageModelConfig
) -> None:
    assert actual.api_key == expected.api_key
    assert actual.azure_auth_type == expected.azure_auth_type
    assert actual.type == expected.type
    assert actual.model == expected.model
    assert actual.encoding_model == expected.encoding_model
    assert actual.max_tokens == expected.max_tokens
    assert actual.temperature == expected.temperature
    assert actual.top_p == expected.top_p
    assert actual.n == expected.n
    assert actual.frequency_penalty == expected.frequency_penalty
    assert actual.presence_penalty == expected.presence_penalty
    assert actual.request_timeout == expected.request_timeout
    assert actual.api_base == expected.api_base
    assert actual.api_version == expected.api_version
    assert actual.deployment_name == expected.deployment_name
    assert actual.organization == expected.organization
    assert actual.proxy == expected.proxy
    assert actual.audience == expected.audience
    assert actual.model_supports_json == expected.model_supports_json
    assert actual.tokens_per_minute == expected.tokens_per_minute
    assert actual.requests_per_minute == expected.requests_per_minute
    assert actual.max_retries == expected.max_retries
    assert actual.max_retry_wait == expected.max_retry_wait
    assert (
        actual.sleep_on_rate_limit_recommendation
        == expected.sleep_on_rate_limit_recommendation
    )
    assert actual.concurrent_requests == expected.concurrent_requests
    assert actual.parallelization_stagger == expected.parallelization_stagger
    assert actual.parallelization_num_threads == expected.parallelization_num_threads
    assert actual.async_mode == expected.async_mode
    if actual.responses is not None:
        assert expected.responses is not None
        assert len(actual.responses) == len(expected.responses)
        for e, a in zip(actual.responses, expected.responses, strict=True):
            assert isinstance(e, BaseModel)
            assert isinstance(a, BaseModel)
            assert e.model_dump() == a.model_dump()
    else:
        assert expected.responses is None


def assert_vector_store_configs(actual: VectorStoreConfig, expected: VectorStoreConfig):
    assert actual.type == expected.type
    assert actual.db_uri == expected.db_uri
    assert actual.container_name == expected.container_name
    assert actual.overwrite == expected.overwrite
    assert actual.url == expected.url
    assert actual.api_key == expected.api_key
    assert actual.audience == expected.audience


def assert_reporting_configs(
    actual: ReportingConfig, expected: ReportingConfig
) -> None:
    assert actual.type == expected.type
    assert actual.base_dir == expected.base_dir
    assert actual.connection_string == expected.connection_string
    assert actual.container_name == expected.container_name
    assert actual.storage_account_blob_url == expected.storage_account_blob_url


def assert_output_configs(actual: OutputConfig, expected: OutputConfig) -> None:
    assert expected.type == actual.type
    assert expected.base_dir == actual.base_dir
    assert expected.connection_string == actual.connection_string
    assert expected.container_name == actual.container_name
    assert expected.storage_account_blob_url == actual.storage_account_blob_url
    assert expected.cosmosdb_account_url == actual.cosmosdb_account_url


def assert_cache_configs(actual: CacheConfig, expected: CacheConfig) -> None:
    assert actual.type == expected.type
    assert actual.base_dir == expected.base_dir
    assert actual.connection_string == expected.connection_string
    assert actual.container_name == expected.container_name
    assert actual.storage_account_blob_url == expected.storage_account_blob_url
    assert actual.cosmosdb_account_url == expected.cosmosdb_account_url


def assert_input_configs(actual: InputConfig, expected: InputConfig) -> None:
    assert actual.type == expected.type
    assert actual.file_type == expected.file_type
    assert actual.base_dir == expected.base_dir
    assert actual.connection_string == expected.connection_string
    assert actual.storage_account_blob_url == expected.storage_account_blob_url
    assert actual.container_name == expected.container_name
    assert actual.encoding == expected.encoding
    assert actual.file_pattern == expected.file_pattern
    assert actual.file_filter == expected.file_filter
    assert actual.source_column == expected.source_column
    assert actual.timestamp_column == expected.timestamp_column
    assert actual.timestamp_format == expected.timestamp_format
    assert actual.text_column == expected.text_column
    assert actual.title_column == expected.title_column
    assert actual.document_attribute_columns == expected.document_attribute_columns
def assert_embed_graph_configs(
    actual: EmbedGraphConfig, expected: EmbedGraphConfig
) -> None:
    assert actual.enabled == expected.enabled
    assert actual.dimensions == expected.dimensions
    assert actual.num_walks == expected.num_walks
    assert actual.walk_length == expected.walk_length
    assert actual.window_size == expected.window_size
    assert actual.iterations == expected.iterations
    assert actual.random_seed == expected.random_seed
    assert actual.use_lcc == expected.use_lcc


def assert_text_embedding_configs(
    actual: TextEmbeddingConfig, expected: TextEmbeddingConfig
) -> None:
    assert actual.batch_size == expected.batch_size
    assert actual.batch_max_tokens == expected.batch_max_tokens
    assert actual.target == expected.target
    assert actual.names == expected.names
    assert actual.strategy == expected.strategy
    assert actual.model_id == expected.model_id


def assert_chunking_configs(actual: ChunkingConfig, expected: ChunkingConfig) -> None:
    assert actual.size == expected.size
    assert actual.overlap == expected.overlap
    assert actual.group_by_columns == expected.group_by_columns
    assert actual.strategy == expected.strategy
    assert actual.encoding_model == expected.encoding_model


def assert_snapshots_configs(
    actual: SnapshotsConfig, expected: SnapshotsConfig
) -> None:
    assert actual.embeddings == expected.embeddings
    assert actual.graphml == expected.graphml
    assert actual.transient == expected.transient


def assert_entity_extraction_configs(
    actual: EntityExtractionConfig, expected: EntityExtractionConfig
) -> None:
    assert actual.prompt == expected.prompt
    assert actual.entity_types == expected.entity_types
    assert actual.max_gleanings == expected.max_gleanings
    assert actual.strategy == expected.strategy
    assert actual.encoding_model == expected.encoding_model
    assert actual.model_id == expected.model_id


def assert_summarize_descriptions_configs(
    actual: SummarizeDescriptionsConfig, expected: SummarizeDescriptionsConfig
) -> None:
    assert actual.prompt == expected.prompt
    assert actual.max_length == expected.max_length
    assert actual.strategy == expected.strategy
    assert actual.model_id == expected.model_id


def assert_community_reports_configs(
    actual: CommunityReportsConfig, expected: CommunityReportsConfig
) -> None:
    assert actual.prompt == expected.prompt
    assert actual.max_length == expected.max_length
    assert actual.max_input_length == expected.max_input_length
    assert actual.strategy == expected.strategy
    assert actual.model_id == expected.model_id


def assert_claim_extraction_configs(
    actual: ClaimExtractionConfig, expected: ClaimExtractionConfig
) -> None:
    assert actual.enabled == expected.enabled
    assert actual.prompt == expected.prompt
    assert actual.description == expected.description
    assert actual.max_gleanings == expected.max_gleanings
    assert actual.strategy == expected.strategy
    assert actual.encoding_model == expected.encoding_model
    assert actual.model_id == expected.model_id


def assert_cluster_graph_configs(
    actual: ClusterGraphConfig, expected: ClusterGraphConfig
) -> None:
    assert actual.max_cluster_size == expected.max_cluster_size
    assert actual.use_lcc == expected.use_lcc
    assert actual.seed == expected.seed


def assert_umap_configs(actual: UmapConfig, expected: UmapConfig) -> None:
    assert actual.enabled == expected.enabled
def assert_local_search_configs(
    actual: LocalSearchConfig, expected: LocalSearchConfig
) -> None:
    assert actual.prompt == expected.prompt
    assert actual.text_unit_prop == expected.text_unit_prop
    assert actual.community_prop == expected.community_prop
    assert (
        actual.conversation_history_max_turns == expected.conversation_history_max_turns
    )
    assert actual.top_k_entities == expected.top_k_entities
    assert actual.top_k_relationships == expected.top_k_relationships
    assert actual.temperature == expected.temperature
    assert actual.top_p == expected.top_p
    assert actual.n == expected.n
    assert actual.max_tokens == expected.max_tokens
    assert actual.llm_max_tokens == expected.llm_max_tokens


def assert_global_search_configs(
    actual: GlobalSearchConfig, expected: GlobalSearchConfig
) -> None:
    assert actual.map_prompt == expected.map_prompt
    assert actual.reduce_prompt == expected.reduce_prompt
    assert actual.knowledge_prompt == expected.knowledge_prompt
    assert actual.temperature == expected.temperature
    assert actual.top_p == expected.top_p
    assert actual.n == expected.n
    assert actual.max_tokens == expected.max_tokens
    assert actual.data_max_tokens == expected.data_max_tokens
    assert actual.map_max_tokens == expected.map_max_tokens
    assert actual.reduce_max_tokens == expected.reduce_max_tokens
    assert actual.concurrency == expected.concurrency
    assert actual.dynamic_search_llm == expected.dynamic_search_llm
    assert actual.dynamic_search_threshold == expected.dynamic_search_threshold
    assert actual.dynamic_search_keep_parent == expected.dynamic_search_keep_parent
    assert actual.dynamic_search_num_repeats == expected.dynamic_search_num_repeats
    assert actual.dynamic_search_use_summary == expected.dynamic_search_use_summary
    assert (
        actual.dynamic_search_concurrent_coroutines
        == expected.dynamic_search_concurrent_coroutines
    )
    assert actual.dynamic_search_max_level == expected.dynamic_search_max_level


def assert_drift_search_configs(
    actual: DRIFTSearchConfig, expected: DRIFTSearchConfig
) -> None:
    assert actual.prompt == expected.prompt
    assert actual.temperature == expected.temperature
    assert actual.top_p == expected.top_p
    assert actual.n == expected.n
    assert actual.max_tokens == expected.max_tokens
    assert actual.data_max_tokens == expected.data_max_tokens
    assert actual.concurrency == expected.concurrency
    assert actual.drift_k_followups == expected.drift_k_followups
    assert actual.primer_folds == expected.primer_folds
    assert actual.primer_llm_max_tokens == expected.primer_llm_max_tokens
    assert actual.n_depth == expected.n_depth
    assert actual.local_search_text_unit_prop == expected.local_search_text_unit_prop
    assert actual.local_search_community_prop == expected.local_search_community_prop
    assert (
        actual.local_search_top_k_mapped_entities
        == expected.local_search_top_k_mapped_entities
    )
    assert (
        actual.local_search_top_k_relationships
        == expected.local_search_top_k_relationships
    )
    assert actual.local_search_max_data_tokens == expected.local_search_max_data_tokens
    assert actual.local_search_temperature == expected.local_search_temperature
    assert actual.local_search_top_p == expected.local_search_top_p
    assert actual.local_search_n == expected.local_search_n
    assert (
        actual.local_search_llm_max_gen_tokens
        == expected.local_search_llm_max_gen_tokens
    )


def assert_basic_search_configs(
    actual: BasicSearchConfig, expected: BasicSearchConfig
) -> None:
    assert actual.prompt == expected.prompt
    assert actual.text_unit_prop == expected.text_unit_prop
    assert (
        actual.conversation_history_max_turns == expected.conversation_history_max_turns
    )
    assert actual.temperature == expected.temperature
    assert actual.top_p == expected.top_p
    assert actual.n == expected.n
    assert actual.max_tokens == expected.max_tokens
    assert actual.llm_max_tokens == expected.llm_max_tokens
def assert_graphrag_configs(actual: GraphRagConfig, expected: GraphRagConfig) -> None:
    assert actual.root_dir == expected.root_dir

    a_keys = sorted(actual.models.keys())
    e_keys = sorted(expected.models.keys())
    assert len(a_keys) == len(e_keys)
    for a, e in zip(a_keys, e_keys, strict=False):
        assert a == e
        assert_language_model_configs(actual.models[a], expected.models[e])

    assert_vector_store_configs(actual.vector_store, expected.vector_store)
    assert_reporting_configs(actual.reporting, expected.reporting)
    assert_output_configs(actual.output, expected.output)

    if actual.update_index_output is not None:
        assert expected.update_index_output is not None
        assert_output_configs(actual.update_index_output, expected.update_index_output)
    else:
        assert expected.update_index_output is None

    assert_cache_configs(actual.cache, expected.cache)
    assert_input_configs(actual.input, expected.input)
    assert_embed_graph_configs(actual.embed_graph, expected.embed_graph)
    assert_text_embedding_configs(actual.embeddings, expected.embeddings)
    assert_chunking_configs(actual.chunks, expected.chunks)
    assert_snapshots_configs(actual.snapshots, expected.snapshots)
    assert_entity_extraction_configs(
        actual.entity_extraction, expected.entity_extraction
    )
    assert_summarize_descriptions_configs(
        actual.summarize_descriptions, expected.summarize_descriptions
    )
    assert_community_reports_configs(
        actual.community_reports, expected.community_reports
    )
    assert_claim_extraction_configs(actual.claim_extraction, expected.claim_extraction)
    assert_cluster_graph_configs(actual.cluster_graph, expected.cluster_graph)
    assert_umap_configs(actual.umap, expected.umap)
    assert_local_search_configs(actual.local_search, expected.local_search)
    assert_global_search_configs(actual.global_search, expected.global_search)
    assert_drift_search_configs(actual.drift_search, expected.drift_search)
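For orientation, a minimal sketch of how these helpers could be exercised in a unit test. The import path and the test body are illustrative assumptions, not part of this change:

from graphrag.config.create_graphrag_config import create_graphrag_config

from .utils import (  # hypothetical import path for the helpers above
    DEFAULT_GRAPHRAG_CONFIG_SETTINGS,
    assert_graphrag_configs,
    get_default_graphrag_config,
)


def test_default_settings_round_trip():
    # Build one config from the raw default settings dict and one from the
    # helper, then compare them field by field with the assertion helpers.
    expected = get_default_graphrag_config()
    actual = create_graphrag_config(DEFAULT_GRAPHRAG_CONFIG_SETTINGS)
    assert_graphrag_configs(actual, expected)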
@ -3,7 +3,7 @@

import pytest

from graphrag.utils.embeddings import create_collection_name
from graphrag.config.embeddings import create_collection_name


def test_create_collection_name():
@ -7,6 +7,7 @@ from graphrag.index.workflows.compute_communities import run_workflow
from graphrag.utils.storage import load_table_from_storage

from .util import (
    DEFAULT_MODEL_CONFIG,
    compare_outputs,
    create_test_context,
    load_test_table,
@ -20,7 +21,7 @@ async def test_compute_communities():
        storage=["base_relationship_edges"],
    )

    config = create_graphrag_config()
    config = create_graphrag_config({"models": DEFAULT_MODEL_CONFIG})

    await run_workflow(
        config,
@ -7,6 +7,7 @@ from graphrag.index.workflows.create_base_text_units import run_workflow, workfl
from graphrag.utils.storage import load_table_from_storage

from .util import (
    DEFAULT_MODEL_CONFIG,
    compare_outputs,
    create_test_context,
    load_test_table,
@ -18,7 +19,7 @@ async def test_create_base_text_units():

    context = await create_test_context()

    config = create_graphrag_config()
    config = create_graphrag_config({"models": DEFAULT_MODEL_CONFIG})
    # test data was created with 4o, so we need to match the encoding for chunks to be identical
    config.chunks.encoding_model = "o200k_base"
@ -10,6 +10,7 @@ from graphrag.index.workflows.create_final_communities import (
from graphrag.utils.storage import load_table_from_storage

from .util import (
    DEFAULT_MODEL_CONFIG,
    compare_outputs,
    create_test_context,
    load_test_table,
@ -27,7 +28,7 @@ async def test_create_final_communities():
        ],
    )

    config = create_graphrag_config()
    config = create_graphrag_config({"models": DEFAULT_MODEL_CONFIG})

    await run_workflow(
        config,
@ -2,8 +2,6 @@
# Licensed under the MIT License


import pytest

from graphrag.callbacks.noop_workflow_callbacks import NoopWorkflowCallbacks
from graphrag.config.create_graphrag_config import create_graphrag_config
from graphrag.config.enums import LLMType
@ -11,7 +9,6 @@ from graphrag.index.operations.summarize_communities.community_reports_extractor
    CommunityReportResponse,
    FindingModel,
)
from graphrag.index.run.derive_from_rows import ParallelizationError
from graphrag.index.workflows.create_final_community_reports import (
    run_workflow,
    workflow_name,
@ -19,6 +16,7 @@ from graphrag.index.workflows.create_final_community_reports import (
from graphrag.utils.storage import load_table_from_storage

from .util import (
    DEFAULT_MODEL_CONFIG,
    compare_outputs,
    create_test_context,
    load_test_table,
@ -41,12 +39,6 @@ MOCK_RESPONSES = [
    )
]

MOCK_LLM_CONFIG = {
    "type": LLMType.StaticResponse,
    "responses": MOCK_RESPONSES,
    "parse_json": True,
}


async def test_create_final_community_reports():
    expected = load_test_table(workflow_name)
@ -61,10 +53,16 @@ async def test_create_final_community_reports():
        ]
    )

    config = create_graphrag_config()
    config = create_graphrag_config({"models": DEFAULT_MODEL_CONFIG})
    llm_settings = config.get_language_model_config(
        config.community_reports.model_id
    ).model_dump()
    llm_settings["type"] = LLMType.StaticResponse
    llm_settings["responses"] = MOCK_RESPONSES
    llm_settings["parse_json"] = True
    config.community_reports.strategy = {
        "type": "graph_intelligence",
        "llm": MOCK_LLM_CONFIG,
        "llm": llm_settings,
    }

    await run_workflow(
@ -83,27 +81,3 @@ async def test_create_final_community_reports():
    # assert a handful of mock data items to confirm they get put in the right spot
    assert actual["rank"][:1][0] == 2
    assert actual["rank_explanation"][:1][0] == "<rating_explanation>"


async def test_create_final_community_reports_missing_llm_throws():
    context = await create_test_context(
        storage=[
            "create_final_nodes",
            "create_final_covariates",
            "create_final_relationships",
            "create_final_entities",
            "create_final_communities",
        ]
    )

    config = create_graphrag_config()
    config.community_reports.strategy = {
        "type": "graph_intelligence",
    }

    with pytest.raises(ParallelizationError):
        await run_workflow(
            config,
            context,
            NoopWorkflowCallbacks(),
        )
@ -1,13 +1,11 @@
# Copyright (c) 2024 Microsoft Corporation.
# Licensed under the MIT License

import pytest
from pandas.testing import assert_series_equal

from graphrag.callbacks.noop_workflow_callbacks import NoopWorkflowCallbacks
from graphrag.config.create_graphrag_config import create_graphrag_config
from graphrag.config.enums import LLMType
from graphrag.index.run.derive_from_rows import ParallelizationError
from graphrag.index.workflows.create_final_covariates import (
    run_workflow,
    workflow_name,
@ -15,6 +13,7 @@ from graphrag.index.workflows.create_final_covariates import (
from graphrag.utils.storage import load_table_from_storage

from .util import (
    DEFAULT_MODEL_CONFIG,
    create_test_context,
    load_test_table,
)
@ -25,8 +24,6 @@ MOCK_LLM_RESPONSES = [
""".strip()
]

MOCK_LLM_CONFIG = {"type": LLMType.StaticResponse, "responses": MOCK_LLM_RESPONSES}


async def test_create_final_covariates():
    input = load_test_table("create_base_text_units")
@ -36,10 +33,15 @@ async def test_create_final_covariates():
        storage=["create_base_text_units"],
    )

    config = create_graphrag_config()
    config = create_graphrag_config({"models": DEFAULT_MODEL_CONFIG})
    llm_settings = config.get_language_model_config(
        config.claim_extraction.model_id
    ).model_dump()
    llm_settings["type"] = LLMType.StaticResponse
    llm_settings["responses"] = MOCK_LLM_RESPONSES
    config.claim_extraction.strategy = {
        "type": "graph_intelligence",
        "llm": MOCK_LLM_CONFIG,
        "llm": llm_settings,
        "claim_description": "description",
    }

@ -78,22 +80,3 @@ async def test_create_final_covariates():
        actual["source_text"][0]
        == "According to an article published on 2022/01/10, Company A was fined for bid rigging while participating in multiple public tenders published by Government Agency B."
    )


async def test_create_final_covariates_missing_llm_throws():
    context = await create_test_context(
        storage=["create_base_text_units"],
    )

    config = create_graphrag_config()
    config.claim_extraction.strategy = {
        "type": "graph_intelligence",
        "claim_description": "description",
    }

    with pytest.raises(ParallelizationError):
        await run_workflow(
            config,
            context,
            NoopWorkflowCallbacks(),
        )
@ -10,6 +10,7 @@ from graphrag.index.workflows.create_final_documents import (
from graphrag.utils.storage import load_table_from_storage

from .util import (
    DEFAULT_MODEL_CONFIG,
    compare_outputs,
    create_test_context,
    load_test_table,
@ -23,7 +24,7 @@ async def test_create_final_documents():
        storage=["create_base_text_units"],
    )

    config = create_graphrag_config()
    config = create_graphrag_config({"models": DEFAULT_MODEL_CONFIG})

    await run_workflow(
        config,
@ -43,7 +44,7 @@ async def test_create_final_documents_with_attribute_columns():
        storage=["create_base_text_units"],
    )

    config = create_graphrag_config()
    config = create_graphrag_config({"models": DEFAULT_MODEL_CONFIG})
    config.input.document_attribute_columns = ["title"]

    await run_workflow(
@ -10,6 +10,7 @@ from graphrag.index.workflows.create_final_entities import (
from graphrag.utils.storage import load_table_from_storage

from .util import (
    DEFAULT_MODEL_CONFIG,
    compare_outputs,
    create_test_context,
    load_test_table,
@ -23,7 +24,7 @@ async def test_create_final_entities():
        storage=["base_entity_nodes"],
    )

    config = create_graphrag_config()
    config = create_graphrag_config({"models": DEFAULT_MODEL_CONFIG})

    await run_workflow(
        config,
@ -10,6 +10,7 @@ from graphrag.index.workflows.create_final_nodes import (
from graphrag.utils.storage import load_table_from_storage

from .util import (
    DEFAULT_MODEL_CONFIG,
    compare_outputs,
    create_test_context,
    load_test_table,
@ -27,7 +28,7 @@ async def test_create_final_nodes():
        ],
    )

    config = create_graphrag_config()
    config = create_graphrag_config({"models": DEFAULT_MODEL_CONFIG})

    await run_workflow(
        config,
@ -11,6 +11,7 @@ from graphrag.index.workflows.create_final_relationships import (
from graphrag.utils.storage import load_table_from_storage

from .util import (
    DEFAULT_MODEL_CONFIG,
    compare_outputs,
    create_test_context,
    load_test_table,
@ -24,7 +25,7 @@ async def test_create_final_relationships():
        storage=["base_relationship_edges"],
    )

    config = create_graphrag_config()
    config = create_graphrag_config({"models": DEFAULT_MODEL_CONFIG})

    await run_workflow(
        config,
Some files were not shown because too many files have changed in this diff.