mirror of
https://github.com/microsoft/graphrag.git
synced 2026-01-14 00:57:23 +08:00
Move input config alongside input readers
This commit is contained in:
parent
99aea5226e
commit
efaaa1f1d0
@ -28,11 +28,10 @@
|
||||
"from pathlib import Path\n",
|
||||
"from pprint import pprint\n",
|
||||
"\n",
|
||||
"import graphrag.api as api\n",
|
||||
"import pandas as pd\n",
|
||||
"from graphrag.config.load_config import load_config\n",
|
||||
"from graphrag.index.typing.pipeline_run_result import PipelineRunResult\n",
|
||||
"\n",
|
||||
"import graphrag.api as api"
|
||||
"from graphrag.index.typing.pipeline_run_result import PipelineRunResult"
|
||||
]
|
||||
},
|
||||
{
|
||||
|
||||
@ -30,11 +30,10 @@
|
||||
"from pathlib import Path\n",
|
||||
"from pprint import pprint\n",
|
||||
"\n",
|
||||
"import graphrag.api as api\n",
|
||||
"import pandas as pd\n",
|
||||
"from graphrag.config.load_config import load_config\n",
|
||||
"from graphrag.index.typing.pipeline_run_result import PipelineRunResult\n",
|
||||
"\n",
|
||||
"import graphrag.api as api"
|
||||
"from graphrag.index.typing.pipeline_run_result import PipelineRunResult"
|
||||
]
|
||||
},
|
||||
{
|
||||
|
||||
@ -9,9 +9,11 @@ import sys
|
||||
import warnings
|
||||
from pathlib import Path
|
||||
|
||||
from graphrag_cache.cache_type import CacheType
|
||||
|
||||
import graphrag.api as api
|
||||
from graphrag.callbacks.console_workflow_callbacks import ConsoleWorkflowCallbacks
|
||||
from graphrag.config.enums import CacheType, IndexingMethod
|
||||
from graphrag.config.enums import IndexingMethod
|
||||
from graphrag.config.load_config import load_config
|
||||
from graphrag.index.validate_config import validate_config_names
|
||||
from graphrag.utils.cli import redact
|
||||
@ -102,7 +104,7 @@ def _run_index(
|
||||
)
|
||||
|
||||
if not cache:
|
||||
config.cache.type = CacheType.none
|
||||
config.cache.type = CacheType.Noop
|
||||
|
||||
if not skip_validation:
|
||||
validate_config_names(config)
|
||||
|
||||
@ -15,12 +15,12 @@ from graphrag.config.enums import (
|
||||
AsyncType,
|
||||
AuthType,
|
||||
ChunkStrategyType,
|
||||
InputFileType,
|
||||
ModelType,
|
||||
NounPhraseExtractorType,
|
||||
ReportingType,
|
||||
VectorStoreType,
|
||||
)
|
||||
from graphrag.index.input.input_file_type import InputFileType
|
||||
from graphrag.index.operations.build_noun_graph.np_extractors.stop_words import (
|
||||
EN_STOP_WORDS,
|
||||
)
|
||||
@ -241,7 +241,7 @@ class InputDefaults:
|
||||
"""Default values for input."""
|
||||
|
||||
storage: InputStorageDefaults = field(default_factory=InputStorageDefaults)
|
||||
file_type: ClassVar[InputFileType] = InputFileType.text
|
||||
file_type: ClassVar[InputFileType] = InputFileType.Text
|
||||
encoding: str = "utf-8"
|
||||
file_pattern: str = ""
|
||||
text_column: str = "text"
|
||||
|
||||
@ -5,7 +5,7 @@
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from enum import Enum, StrEnum
|
||||
from enum import Enum
|
||||
|
||||
|
||||
class CacheType(str, Enum):
|
||||
@ -27,38 +27,6 @@ class CacheType(str, Enum):
|
||||
return f'"{self.value}"'
|
||||
|
||||
|
||||
class InputFileType(StrEnum):
|
||||
"""The input file type for the pipeline."""
|
||||
|
||||
csv = "csv"
|
||||
"""The CSV input type."""
|
||||
text = "text"
|
||||
"""The text input type."""
|
||||
json = "json"
|
||||
"""The JSON input type."""
|
||||
|
||||
def __repr__(self):
|
||||
"""Get a string representation."""
|
||||
return f'"{self.value}"'
|
||||
|
||||
|
||||
class StorageType(str, Enum):
|
||||
"""The output type for the pipeline."""
|
||||
|
||||
file = "file"
|
||||
"""The file output type."""
|
||||
memory = "memory"
|
||||
"""The memory output type."""
|
||||
blob = "blob"
|
||||
"""The blob output type."""
|
||||
cosmosdb = "cosmosdb"
|
||||
"""The cosmosdb output type"""
|
||||
|
||||
def __repr__(self):
|
||||
"""Get a string representation."""
|
||||
return f'"{self.value}"'
|
||||
|
||||
|
||||
class VectorStoreType(str, Enum):
|
||||
"""The supported vector store types."""
|
||||
|
||||
|
||||
@ -24,7 +24,6 @@ from graphrag.config.models.extract_claims_config import ExtractClaimsConfig
|
||||
from graphrag.config.models.extract_graph_config import ExtractGraphConfig
|
||||
from graphrag.config.models.extract_graph_nlp_config import ExtractGraphNLPConfig
|
||||
from graphrag.config.models.global_search_config import GlobalSearchConfig
|
||||
from graphrag.config.models.input_config import InputConfig
|
||||
from graphrag.config.models.language_model_config import LanguageModelConfig
|
||||
from graphrag.config.models.local_search_config import LocalSearchConfig
|
||||
from graphrag.config.models.prune_graph_config import PruneGraphConfig
|
||||
@ -34,6 +33,7 @@ from graphrag.config.models.summarize_descriptions_config import (
|
||||
SummarizeDescriptionsConfig,
|
||||
)
|
||||
from graphrag.config.models.vector_store_config import VectorStoreConfig
|
||||
from graphrag.index.input.input_config import InputConfig
|
||||
from graphrag.language_model.providers.litellm.services.rate_limiter.rate_limiter_factory import (
|
||||
RateLimiterFactory,
|
||||
)
|
||||
@ -110,7 +110,7 @@ class GraphRagConfig(BaseModel):
|
||||
def _validate_input_pattern(self) -> None:
|
||||
"""Validate the input file pattern based on the specified type."""
|
||||
if len(self.input.file_pattern) == 0:
|
||||
if self.input.file_type == defs.InputFileType.text:
|
||||
if self.input.file_type == defs.InputFileType.Text:
|
||||
self.input.file_pattern = ".*\\.txt$"
|
||||
else:
|
||||
self.input.file_pattern = f".*\\.{self.input.file_type}$"
|
||||
|
||||
@ -14,7 +14,7 @@ class InputConfig(BaseModel):
|
||||
"""The default configuration section for Input."""
|
||||
|
||||
model_config = ConfigDict(extra="allow")
|
||||
"""Allow extra fields to support custom cache implementations."""
|
||||
"""Allow extra fields to support custom reader implementations."""
|
||||
|
||||
storage: StorageConfig = Field(
|
||||
description="The storage configuration to use for reading input documents.",
|
||||
@ -34,7 +34,7 @@ class InputConfig(BaseModel):
|
||||
description="The input file pattern to use.",
|
||||
default=graphrag_config_defaults.input.file_pattern,
|
||||
)
|
||||
text_column: str = Field(
|
||||
text_column: str | None = Field(
|
||||
description="The input text column to use.",
|
||||
default=graphrag_config_defaults.input.text_column,
|
||||
)
|
||||
21
packages/graphrag/graphrag/index/input/input_file_type.py
Normal file
21
packages/graphrag/graphrag/index/input/input_file_type.py
Normal file
@ -0,0 +1,21 @@
|
||||
# Copyright (c) 2024 Microsoft Corporation.
|
||||
# Licensed under the MIT License
|
||||
|
||||
"""A module containing input file type enum."""
|
||||
|
||||
from enum import StrEnum
|
||||
|
||||
|
||||
class InputFileType(StrEnum):
|
||||
"""The input file type for the pipeline."""
|
||||
|
||||
Csv = "csv"
|
||||
"""The CSV input type."""
|
||||
Text = "text"
|
||||
"""The text input type."""
|
||||
Json = "json"
|
||||
"""The JSON input type."""
|
||||
|
||||
def __repr__(self):
|
||||
"""Get a string representation."""
|
||||
return f'"{self.value}"'
|
||||
@ -10,8 +10,8 @@ from graphrag_common.factory import Factory
|
||||
from graphrag_common.factory.factory import ServiceScope
|
||||
from graphrag_storage.storage import Storage
|
||||
|
||||
from graphrag.config.enums import InputFileType
|
||||
from graphrag.config.models.input_config import InputConfig
|
||||
from graphrag.index.input.input_config import InputConfig
|
||||
from graphrag.index.input.input_file_type import InputFileType
|
||||
from graphrag.index.input.input_reader import InputReader
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
@ -61,18 +61,18 @@ def create_input_reader(config: InputConfig, storage: Storage) -> InputReader:
|
||||
|
||||
if input_strategy not in input_reader_factory:
|
||||
match input_strategy:
|
||||
case InputFileType.csv:
|
||||
case InputFileType.Csv:
|
||||
from graphrag.index.input.csv import CSVFileReader
|
||||
|
||||
register_input_reader(InputFileType.csv, CSVFileReader)
|
||||
case InputFileType.text:
|
||||
register_input_reader(InputFileType.Csv, CSVFileReader)
|
||||
case InputFileType.Text:
|
||||
from graphrag.index.input.text import TextFileReader
|
||||
|
||||
register_input_reader(InputFileType.text, TextFileReader)
|
||||
case InputFileType.json:
|
||||
register_input_reader(InputFileType.Text, TextFileReader)
|
||||
case InputFileType.Json:
|
||||
from graphrag.index.input.json import JSONFileReader
|
||||
|
||||
register_input_reader(InputFileType.json, JSONFileReader)
|
||||
register_input_reader(InputFileType.Json, JSONFileReader)
|
||||
case _:
|
||||
msg = f"InputConfig.file_type '{input_strategy}' is not registered in the InputReaderFactory. Registered types: {', '.join(input_reader_factory.keys())}."
|
||||
raise ValueError(msg)
|
||||
|
||||
@ -18,7 +18,6 @@ from graphrag.config.models.extract_graph_nlp_config import (
|
||||
)
|
||||
from graphrag.config.models.global_search_config import GlobalSearchConfig
|
||||
from graphrag.config.models.graph_rag_config import GraphRagConfig
|
||||
from graphrag.config.models.input_config import InputConfig
|
||||
from graphrag.config.models.language_model_config import LanguageModelConfig
|
||||
from graphrag.config.models.local_search_config import LocalSearchConfig
|
||||
from graphrag.config.models.prune_graph_config import PruneGraphConfig
|
||||
@ -28,6 +27,7 @@ from graphrag.config.models.summarize_descriptions_config import (
|
||||
SummarizeDescriptionsConfig,
|
||||
)
|
||||
from graphrag.config.models.vector_store_config import VectorStoreConfig
|
||||
from graphrag.index.input.input_config import InputConfig
|
||||
from graphrag_cache import CacheConfig
|
||||
from graphrag_storage import StorageConfig
|
||||
from pydantic import BaseModel
|
||||
|
||||
@ -1,8 +1,8 @@
|
||||
# Copyright (c) 2024 Microsoft Corporation.
|
||||
# Licensed under the MIT License
|
||||
|
||||
from graphrag.config.enums import InputFileType
|
||||
from graphrag.config.models.input_config import InputConfig
|
||||
from graphrag.index.input.input_config import InputConfig
|
||||
from graphrag.index.input.input_file_type import InputFileType
|
||||
from graphrag.index.input.input_reader_factory import (
|
||||
create_input_reader,
|
||||
)
|
||||
@ -14,7 +14,7 @@ async def test_csv_loader_one_file():
|
||||
storage=StorageConfig(
|
||||
base_dir="tests/unit/indexing/input/data/one-csv",
|
||||
),
|
||||
file_type=InputFileType.csv,
|
||||
file_type=InputFileType.Csv,
|
||||
file_pattern=".*\\.csv$",
|
||||
)
|
||||
storage = create_storage(config.storage)
|
||||
@ -29,7 +29,7 @@ async def test_csv_loader_one_file_with_title():
|
||||
storage=StorageConfig(
|
||||
base_dir="tests/unit/indexing/input/data/one-csv",
|
||||
),
|
||||
file_type=InputFileType.csv,
|
||||
file_type=InputFileType.Csv,
|
||||
file_pattern=".*\\.csv$",
|
||||
title_column="title",
|
||||
)
|
||||
@ -45,7 +45,7 @@ async def test_csv_loader_one_file_with_metadata():
|
||||
storage=StorageConfig(
|
||||
base_dir="tests/unit/indexing/input/data/one-csv",
|
||||
),
|
||||
file_type=InputFileType.csv,
|
||||
file_type=InputFileType.Csv,
|
||||
file_pattern=".*\\.csv$",
|
||||
title_column="title",
|
||||
metadata=["title"],
|
||||
@ -62,7 +62,7 @@ async def test_csv_loader_multiple_files():
|
||||
storage=StorageConfig(
|
||||
base_dir="tests/unit/indexing/input/data/multiple-csvs",
|
||||
),
|
||||
file_type=InputFileType.csv,
|
||||
file_type=InputFileType.Csv,
|
||||
file_pattern=".*\\.csv$",
|
||||
)
|
||||
storage = create_storage(config.storage)
|
||||
|
||||
@ -1,8 +1,8 @@
|
||||
# Copyright (c) 2024 Microsoft Corporation.
|
||||
# Licensed under the MIT License
|
||||
|
||||
from graphrag.config.enums import InputFileType
|
||||
from graphrag.config.models.input_config import InputConfig
|
||||
from graphrag.index.input.input_config import InputConfig
|
||||
from graphrag.index.input.input_file_type import InputFileType
|
||||
from graphrag.index.input.input_reader_factory import create_input_reader
|
||||
from graphrag_storage import StorageConfig, create_storage
|
||||
|
||||
@ -12,7 +12,7 @@ async def test_json_loader_one_file_one_object():
|
||||
storage=StorageConfig(
|
||||
base_dir="tests/unit/indexing/input/data/one-json-one-object",
|
||||
),
|
||||
file_type=InputFileType.json,
|
||||
file_type=InputFileType.Json,
|
||||
file_pattern=".*\\.json$",
|
||||
)
|
||||
storage = create_storage(config.storage)
|
||||
@ -27,7 +27,7 @@ async def test_json_loader_one_file_multiple_objects():
|
||||
storage=StorageConfig(
|
||||
base_dir="tests/unit/indexing/input/data/one-json-multiple-objects",
|
||||
),
|
||||
file_type=InputFileType.json,
|
||||
file_type=InputFileType.Json,
|
||||
file_pattern=".*\\.json$",
|
||||
)
|
||||
storage = create_storage(config.storage)
|
||||
@ -43,7 +43,7 @@ async def test_json_loader_one_file_with_title():
|
||||
storage=StorageConfig(
|
||||
base_dir="tests/unit/indexing/input/data/one-json-one-object",
|
||||
),
|
||||
file_type=InputFileType.json,
|
||||
file_type=InputFileType.Json,
|
||||
file_pattern=".*\\.json$",
|
||||
title_column="title",
|
||||
)
|
||||
@ -59,7 +59,7 @@ async def test_json_loader_one_file_with_metadata():
|
||||
storage=StorageConfig(
|
||||
base_dir="tests/unit/indexing/input/data/one-json-one-object",
|
||||
),
|
||||
file_type=InputFileType.json,
|
||||
file_type=InputFileType.Json,
|
||||
file_pattern=".*\\.json$",
|
||||
title_column="title",
|
||||
metadata=["title"],
|
||||
@ -76,7 +76,7 @@ async def test_json_loader_multiple_files():
|
||||
storage=StorageConfig(
|
||||
base_dir="tests/unit/indexing/input/data/multiple-jsons",
|
||||
),
|
||||
file_type=InputFileType.json,
|
||||
file_type=InputFileType.Json,
|
||||
file_pattern=".*\\.json$",
|
||||
)
|
||||
storage = create_storage(config.storage)
|
||||
|
||||
@ -1,8 +1,8 @@
|
||||
# Copyright (c) 2024 Microsoft Corporation.
|
||||
# Licensed under the MIT License
|
||||
|
||||
from graphrag.config.enums import InputFileType
|
||||
from graphrag.config.models.input_config import InputConfig
|
||||
from graphrag.index.input.input_config import InputConfig
|
||||
from graphrag.index.input.input_file_type import InputFileType
|
||||
from graphrag.index.input.input_reader_factory import create_input_reader
|
||||
from graphrag_storage import StorageConfig, create_storage
|
||||
|
||||
@ -12,7 +12,7 @@ async def test_txt_loader_one_file():
|
||||
storage=StorageConfig(
|
||||
base_dir="tests/unit/indexing/input/data/one-txt",
|
||||
),
|
||||
file_type=InputFileType.text,
|
||||
file_type=InputFileType.Text,
|
||||
file_pattern=".*\\.txt$",
|
||||
)
|
||||
storage = create_storage(config.storage)
|
||||
@ -27,7 +27,7 @@ async def test_txt_loader_one_file_with_metadata():
|
||||
storage=StorageConfig(
|
||||
base_dir="tests/unit/indexing/input/data/one-txt",
|
||||
),
|
||||
file_type=InputFileType.text,
|
||||
file_type=InputFileType.Text,
|
||||
file_pattern=".*\\.txt$",
|
||||
metadata=["title"],
|
||||
)
|
||||
@ -44,7 +44,7 @@ async def test_txt_loader_multiple_files():
|
||||
storage=StorageConfig(
|
||||
base_dir="tests/unit/indexing/input/data/multiple-txts",
|
||||
),
|
||||
file_type=InputFileType.text,
|
||||
file_type=InputFileType.Text,
|
||||
file_pattern=".*\\.txt$",
|
||||
)
|
||||
storage = create_storage(config.storage)
|
||||
|
||||
@ -7,6 +7,7 @@ import asyncio
|
||||
import logging
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
import graphrag.api as api
|
||||
import streamlit as st
|
||||
from knowledge_loader.data_sources.loader import (
|
||||
create_datasource,
|
||||
@ -17,8 +18,6 @@ from rag.typing import SearchResult, SearchType
|
||||
from state.session_variables import SessionVariables
|
||||
from ui.search import display_search_result
|
||||
|
||||
import graphrag.api as api
|
||||
|
||||
if TYPE_CHECKING:
|
||||
import pandas as pd
|
||||
|
||||
|
||||
Loading…
Reference in New Issue
Block a user