diff --git a/docs/examples_notebooks/api_overview.ipynb b/docs/examples_notebooks/api_overview.ipynb index abcd7832..2a0c0f15 100644 --- a/docs/examples_notebooks/api_overview.ipynb +++ b/docs/examples_notebooks/api_overview.ipynb @@ -28,11 +28,10 @@ "from pathlib import Path\n", "from pprint import pprint\n", "\n", + "import graphrag.api as api\n", "import pandas as pd\n", "from graphrag.config.load_config import load_config\n", - "from graphrag.index.typing.pipeline_run_result import PipelineRunResult\n", - "\n", - "import graphrag.api as api" + "from graphrag.index.typing.pipeline_run_result import PipelineRunResult" ] }, { diff --git a/docs/examples_notebooks/input_documents.ipynb b/docs/examples_notebooks/input_documents.ipynb index 505c0fe1..5657770e 100644 --- a/docs/examples_notebooks/input_documents.ipynb +++ b/docs/examples_notebooks/input_documents.ipynb @@ -30,11 +30,10 @@ "from pathlib import Path\n", "from pprint import pprint\n", "\n", + "import graphrag.api as api\n", "import pandas as pd\n", "from graphrag.config.load_config import load_config\n", - "from graphrag.index.typing.pipeline_run_result import PipelineRunResult\n", - "\n", - "import graphrag.api as api" + "from graphrag.index.typing.pipeline_run_result import PipelineRunResult" ] }, { diff --git a/packages/graphrag/graphrag/cli/index.py b/packages/graphrag/graphrag/cli/index.py index 60b724ea..94a40cde 100644 --- a/packages/graphrag/graphrag/cli/index.py +++ b/packages/graphrag/graphrag/cli/index.py @@ -9,9 +9,11 @@ import sys import warnings from pathlib import Path +from graphrag_cache.cache_type import CacheType + import graphrag.api as api from graphrag.callbacks.console_workflow_callbacks import ConsoleWorkflowCallbacks -from graphrag.config.enums import CacheType, IndexingMethod +from graphrag.config.enums import IndexingMethod from graphrag.config.load_config import load_config from graphrag.index.validate_config import validate_config_names from graphrag.utils.cli import redact @@ -102,7 +104,7 @@ def _run_index( ) if not cache: - config.cache.type = CacheType.none + config.cache.type = CacheType.Noop if not skip_validation: validate_config_names(config) diff --git a/packages/graphrag/graphrag/config/defaults.py b/packages/graphrag/graphrag/config/defaults.py index e330c824..2edf1f90 100644 --- a/packages/graphrag/graphrag/config/defaults.py +++ b/packages/graphrag/graphrag/config/defaults.py @@ -15,12 +15,12 @@ from graphrag.config.enums import ( AsyncType, AuthType, ChunkStrategyType, - InputFileType, ModelType, NounPhraseExtractorType, ReportingType, VectorStoreType, ) +from graphrag.index.input.input_file_type import InputFileType from graphrag.index.operations.build_noun_graph.np_extractors.stop_words import ( EN_STOP_WORDS, ) @@ -241,7 +241,7 @@ class InputDefaults: """Default values for input.""" storage: InputStorageDefaults = field(default_factory=InputStorageDefaults) - file_type: ClassVar[InputFileType] = InputFileType.text + file_type: ClassVar[InputFileType] = InputFileType.Text encoding: str = "utf-8" file_pattern: str = "" text_column: str = "text" diff --git a/packages/graphrag/graphrag/config/enums.py b/packages/graphrag/graphrag/config/enums.py index a5b6b74a..b615e547 100644 --- a/packages/graphrag/graphrag/config/enums.py +++ b/packages/graphrag/graphrag/config/enums.py @@ -5,7 +5,7 @@ from __future__ import annotations -from enum import Enum, StrEnum +from enum import Enum class CacheType(str, Enum): @@ -27,38 +27,6 @@ class CacheType(str, Enum): return f'"{self.value}"' -class InputFileType(StrEnum): - """The input file type for the pipeline.""" - - csv = "csv" - """The CSV input type.""" - text = "text" - """The text input type.""" - json = "json" - """The JSON input type.""" - - def __repr__(self): - """Get a string representation.""" - return f'"{self.value}"' - - -class StorageType(str, Enum): - """The output type for the pipeline.""" - - file = "file" - """The file output type.""" - memory = "memory" - """The memory output type.""" - blob = "blob" - """The blob output type.""" - cosmosdb = "cosmosdb" - """The cosmosdb output type""" - - def __repr__(self): - """Get a string representation.""" - return f'"{self.value}"' - - class VectorStoreType(str, Enum): """The supported vector store types.""" diff --git a/packages/graphrag/graphrag/config/models/graph_rag_config.py b/packages/graphrag/graphrag/config/models/graph_rag_config.py index 9a1ac89b..10a2f41d 100644 --- a/packages/graphrag/graphrag/config/models/graph_rag_config.py +++ b/packages/graphrag/graphrag/config/models/graph_rag_config.py @@ -24,7 +24,6 @@ from graphrag.config.models.extract_claims_config import ExtractClaimsConfig from graphrag.config.models.extract_graph_config import ExtractGraphConfig from graphrag.config.models.extract_graph_nlp_config import ExtractGraphNLPConfig from graphrag.config.models.global_search_config import GlobalSearchConfig -from graphrag.config.models.input_config import InputConfig from graphrag.config.models.language_model_config import LanguageModelConfig from graphrag.config.models.local_search_config import LocalSearchConfig from graphrag.config.models.prune_graph_config import PruneGraphConfig @@ -34,6 +33,7 @@ from graphrag.config.models.summarize_descriptions_config import ( SummarizeDescriptionsConfig, ) from graphrag.config.models.vector_store_config import VectorStoreConfig +from graphrag.index.input.input_config import InputConfig from graphrag.language_model.providers.litellm.services.rate_limiter.rate_limiter_factory import ( RateLimiterFactory, ) @@ -110,7 +110,7 @@ class GraphRagConfig(BaseModel): def _validate_input_pattern(self) -> None: """Validate the input file pattern based on the specified type.""" if len(self.input.file_pattern) == 0: - if self.input.file_type == defs.InputFileType.text: + if self.input.file_type == defs.InputFileType.Text: self.input.file_pattern = ".*\\.txt$" else: self.input.file_pattern = f".*\\.{self.input.file_type}$" diff --git a/packages/graphrag/graphrag/config/models/input_config.py b/packages/graphrag/graphrag/index/input/input_config.py similarity index 93% rename from packages/graphrag/graphrag/config/models/input_config.py rename to packages/graphrag/graphrag/index/input/input_config.py index e3ebb9cd..7a86075f 100644 --- a/packages/graphrag/graphrag/config/models/input_config.py +++ b/packages/graphrag/graphrag/index/input/input_config.py @@ -14,7 +14,7 @@ class InputConfig(BaseModel): """The default configuration section for Input.""" model_config = ConfigDict(extra="allow") - """Allow extra fields to support custom cache implementations.""" + """Allow extra fields to support custom reader implementations.""" storage: StorageConfig = Field( description="The storage configuration to use for reading input documents.", @@ -34,7 +34,7 @@ class InputConfig(BaseModel): description="The input file pattern to use.", default=graphrag_config_defaults.input.file_pattern, ) - text_column: str = Field( + text_column: str | None = Field( description="The input text column to use.", default=graphrag_config_defaults.input.text_column, ) diff --git a/packages/graphrag/graphrag/index/input/input_file_type.py b/packages/graphrag/graphrag/index/input/input_file_type.py new file mode 100644 index 00000000..3996eee3 --- /dev/null +++ b/packages/graphrag/graphrag/index/input/input_file_type.py @@ -0,0 +1,21 @@ +# Copyright (c) 2024 Microsoft Corporation. +# Licensed under the MIT License + +"""A module containing input file type enum.""" + +from enum import StrEnum + + +class InputFileType(StrEnum): + """The input file type for the pipeline.""" + + Csv = "csv" + """The CSV input type.""" + Text = "text" + """The text input type.""" + Json = "json" + """The JSON input type.""" + + def __repr__(self): + """Get a string representation.""" + return f'"{self.value}"' diff --git a/packages/graphrag/graphrag/index/input/input_reader_factory.py b/packages/graphrag/graphrag/index/input/input_reader_factory.py index a548ad47..e0b391d7 100644 --- a/packages/graphrag/graphrag/index/input/input_reader_factory.py +++ b/packages/graphrag/graphrag/index/input/input_reader_factory.py @@ -10,8 +10,8 @@ from graphrag_common.factory import Factory from graphrag_common.factory.factory import ServiceScope from graphrag_storage.storage import Storage -from graphrag.config.enums import InputFileType -from graphrag.config.models.input_config import InputConfig +from graphrag.index.input.input_config import InputConfig +from graphrag.index.input.input_file_type import InputFileType from graphrag.index.input.input_reader import InputReader logger = logging.getLogger(__name__) @@ -61,18 +61,18 @@ def create_input_reader(config: InputConfig, storage: Storage) -> InputReader: if input_strategy not in input_reader_factory: match input_strategy: - case InputFileType.csv: + case InputFileType.Csv: from graphrag.index.input.csv import CSVFileReader - register_input_reader(InputFileType.csv, CSVFileReader) - case InputFileType.text: + register_input_reader(InputFileType.Csv, CSVFileReader) + case InputFileType.Text: from graphrag.index.input.text import TextFileReader - register_input_reader(InputFileType.text, TextFileReader) - case InputFileType.json: + register_input_reader(InputFileType.Text, TextFileReader) + case InputFileType.Json: from graphrag.index.input.json import JSONFileReader - register_input_reader(InputFileType.json, JSONFileReader) + register_input_reader(InputFileType.Json, JSONFileReader) case _: msg = f"InputConfig.file_type '{input_strategy}' is not registered in the InputReaderFactory. Registered types: {', '.join(input_reader_factory.keys())}." raise ValueError(msg) diff --git a/tests/unit/config/utils.py b/tests/unit/config/utils.py index 1315907e..3f505275 100644 --- a/tests/unit/config/utils.py +++ b/tests/unit/config/utils.py @@ -18,7 +18,6 @@ from graphrag.config.models.extract_graph_nlp_config import ( ) from graphrag.config.models.global_search_config import GlobalSearchConfig from graphrag.config.models.graph_rag_config import GraphRagConfig -from graphrag.config.models.input_config import InputConfig from graphrag.config.models.language_model_config import LanguageModelConfig from graphrag.config.models.local_search_config import LocalSearchConfig from graphrag.config.models.prune_graph_config import PruneGraphConfig @@ -28,6 +27,7 @@ from graphrag.config.models.summarize_descriptions_config import ( SummarizeDescriptionsConfig, ) from graphrag.config.models.vector_store_config import VectorStoreConfig +from graphrag.index.input.input_config import InputConfig from graphrag_cache import CacheConfig from graphrag_storage import StorageConfig from pydantic import BaseModel diff --git a/tests/unit/indexing/input/test_csv_loader.py b/tests/unit/indexing/input/test_csv_loader.py index 938ce4ca..f4c1c3cf 100644 --- a/tests/unit/indexing/input/test_csv_loader.py +++ b/tests/unit/indexing/input/test_csv_loader.py @@ -1,8 +1,8 @@ # Copyright (c) 2024 Microsoft Corporation. # Licensed under the MIT License -from graphrag.config.enums import InputFileType -from graphrag.config.models.input_config import InputConfig +from graphrag.index.input.input_config import InputConfig +from graphrag.index.input.input_file_type import InputFileType from graphrag.index.input.input_reader_factory import ( create_input_reader, ) @@ -14,7 +14,7 @@ async def test_csv_loader_one_file(): storage=StorageConfig( base_dir="tests/unit/indexing/input/data/one-csv", ), - file_type=InputFileType.csv, + file_type=InputFileType.Csv, file_pattern=".*\\.csv$", ) storage = create_storage(config.storage) @@ -29,7 +29,7 @@ async def test_csv_loader_one_file_with_title(): storage=StorageConfig( base_dir="tests/unit/indexing/input/data/one-csv", ), - file_type=InputFileType.csv, + file_type=InputFileType.Csv, file_pattern=".*\\.csv$", title_column="title", ) @@ -45,7 +45,7 @@ async def test_csv_loader_one_file_with_metadata(): storage=StorageConfig( base_dir="tests/unit/indexing/input/data/one-csv", ), - file_type=InputFileType.csv, + file_type=InputFileType.Csv, file_pattern=".*\\.csv$", title_column="title", metadata=["title"], @@ -62,7 +62,7 @@ async def test_csv_loader_multiple_files(): storage=StorageConfig( base_dir="tests/unit/indexing/input/data/multiple-csvs", ), - file_type=InputFileType.csv, + file_type=InputFileType.Csv, file_pattern=".*\\.csv$", ) storage = create_storage(config.storage) diff --git a/tests/unit/indexing/input/test_json_loader.py b/tests/unit/indexing/input/test_json_loader.py index d5cb0044..d604ca9c 100644 --- a/tests/unit/indexing/input/test_json_loader.py +++ b/tests/unit/indexing/input/test_json_loader.py @@ -1,8 +1,8 @@ # Copyright (c) 2024 Microsoft Corporation. # Licensed under the MIT License -from graphrag.config.enums import InputFileType -from graphrag.config.models.input_config import InputConfig +from graphrag.index.input.input_config import InputConfig +from graphrag.index.input.input_file_type import InputFileType from graphrag.index.input.input_reader_factory import create_input_reader from graphrag_storage import StorageConfig, create_storage @@ -12,7 +12,7 @@ async def test_json_loader_one_file_one_object(): storage=StorageConfig( base_dir="tests/unit/indexing/input/data/one-json-one-object", ), - file_type=InputFileType.json, + file_type=InputFileType.Json, file_pattern=".*\\.json$", ) storage = create_storage(config.storage) @@ -27,7 +27,7 @@ async def test_json_loader_one_file_multiple_objects(): storage=StorageConfig( base_dir="tests/unit/indexing/input/data/one-json-multiple-objects", ), - file_type=InputFileType.json, + file_type=InputFileType.Json, file_pattern=".*\\.json$", ) storage = create_storage(config.storage) @@ -43,7 +43,7 @@ async def test_json_loader_one_file_with_title(): storage=StorageConfig( base_dir="tests/unit/indexing/input/data/one-json-one-object", ), - file_type=InputFileType.json, + file_type=InputFileType.Json, file_pattern=".*\\.json$", title_column="title", ) @@ -59,7 +59,7 @@ async def test_json_loader_one_file_with_metadata(): storage=StorageConfig( base_dir="tests/unit/indexing/input/data/one-json-one-object", ), - file_type=InputFileType.json, + file_type=InputFileType.Json, file_pattern=".*\\.json$", title_column="title", metadata=["title"], @@ -76,7 +76,7 @@ async def test_json_loader_multiple_files(): storage=StorageConfig( base_dir="tests/unit/indexing/input/data/multiple-jsons", ), - file_type=InputFileType.json, + file_type=InputFileType.Json, file_pattern=".*\\.json$", ) storage = create_storage(config.storage) diff --git a/tests/unit/indexing/input/test_txt_loader.py b/tests/unit/indexing/input/test_txt_loader.py index 0395741d..67d07646 100644 --- a/tests/unit/indexing/input/test_txt_loader.py +++ b/tests/unit/indexing/input/test_txt_loader.py @@ -1,8 +1,8 @@ # Copyright (c) 2024 Microsoft Corporation. # Licensed under the MIT License -from graphrag.config.enums import InputFileType -from graphrag.config.models.input_config import InputConfig +from graphrag.index.input.input_config import InputConfig +from graphrag.index.input.input_file_type import InputFileType from graphrag.index.input.input_reader_factory import create_input_reader from graphrag_storage import StorageConfig, create_storage @@ -12,7 +12,7 @@ async def test_txt_loader_one_file(): storage=StorageConfig( base_dir="tests/unit/indexing/input/data/one-txt", ), - file_type=InputFileType.text, + file_type=InputFileType.Text, file_pattern=".*\\.txt$", ) storage = create_storage(config.storage) @@ -27,7 +27,7 @@ async def test_txt_loader_one_file_with_metadata(): storage=StorageConfig( base_dir="tests/unit/indexing/input/data/one-txt", ), - file_type=InputFileType.text, + file_type=InputFileType.Text, file_pattern=".*\\.txt$", metadata=["title"], ) @@ -44,7 +44,7 @@ async def test_txt_loader_multiple_files(): storage=StorageConfig( base_dir="tests/unit/indexing/input/data/multiple-txts", ), - file_type=InputFileType.text, + file_type=InputFileType.Text, file_pattern=".*\\.txt$", ) storage = create_storage(config.storage) diff --git a/unified-search-app/app/app_logic.py b/unified-search-app/app/app_logic.py index a573b9da..dc64e0e7 100644 --- a/unified-search-app/app/app_logic.py +++ b/unified-search-app/app/app_logic.py @@ -7,6 +7,7 @@ import asyncio import logging from typing import TYPE_CHECKING +import graphrag.api as api import streamlit as st from knowledge_loader.data_sources.loader import ( create_datasource, @@ -17,8 +18,6 @@ from rag.typing import SearchResult, SearchType from state.session_variables import SessionVariables from ui.search import display_search_result -import graphrag.api as api - if TYPE_CHECKING: import pandas as pd