mirror of https://github.com/microsoft/graphrag.git (synced 2026-01-14 00:57:23 +08:00)
Clean up optional column configs

parent c73263d02c
commit b265612828
@@ -244,8 +244,9 @@ class InputDefaults:
     file_type: ClassVar[InputFileType] = InputFileType.Text
     encoding: str | None = None
     file_pattern: None = None
-    text_column: str = "text"
+    id_column: None = None
     title_column: None = None
+    text_column: None = None
     metadata: None = None
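For reference, the resulting defaults block reads roughly as follows — a minimal sketch assuming a dataclass and stubbing InputFileType, since only the shape of the optional column fields is being illustrated:

```python
from dataclasses import dataclass
from enum import Enum
from typing import ClassVar


class InputFileType(str, Enum):
    # Stub of the real enum; only Text is needed for this sketch.
    Text = "text"


@dataclass
class InputDefaults:
    """Defaults for input configuration; all column overrides are optional."""

    file_type: ClassVar[InputFileType] = InputFileType.Text
    encoding: str | None = None
    file_pattern: None = None
    id_column: None = None
    title_column: None = None
    text_column: None = None
    metadata: None = None
```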
@@ -29,7 +29,9 @@ class CSVFileReader(InputReader):
         """
         buffer = BytesIO(await self._storage.get(path, as_bytes=True))
         data = pd.read_csv(buffer, encoding=self._encoding)
-        data = process_data_columns(data, path, self._text_column, self._title_column)
+        data = process_data_columns(
+            data, path, self._id_column, self._title_column, self._text_column
+        )
         creation_date = await self._storage.get_creation_date(path)
         data["creation_date"] = data.apply(lambda _: creation_date, axis=1)
         return data
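A minimal non-async sketch of the same read path; the raw bytes and column names here are hypothetical stand-ins for what a storage backend would return:

```python
from io import BytesIO

import pandas as pd

# Stand-in for the bytes fetched from storage (as_bytes=True in the reader).
raw = b"doc_id,headline,body\n1,First,Some text\n2,Second,More text\n"

data = pd.read_csv(BytesIO(raw), encoding="utf-8")

# After column processing, every row is stamped with the file's creation date.
creation_date = "2026-01-13"  # hypothetical value from storage metadata
data["creation_date"] = data.apply(lambda _: creation_date, axis=1)
print(data.columns.tolist())  # ['doc_id', 'headline', 'body', 'creation_date']
```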
@@ -34,14 +34,18 @@ class InputConfig(BaseModel):
         description="The input file pattern to use.",
         default=graphrag_config_defaults.input.file_pattern,
     )
-    text_column: str | None = Field(
-        description="The input text column to use.",
-        default=graphrag_config_defaults.input.text_column,
+    id_column: str | None = Field(
+        description="The input ID column to use.",
+        default=graphrag_config_defaults.input.id_column,
     )
     title_column: str | None = Field(
         description="The input title column to use.",
         default=graphrag_config_defaults.input.title_column,
     )
+    text_column: str | None = Field(
+        description="The input text column to use.",
+        default=graphrag_config_defaults.input.text_column,
+    )
     metadata: list[str] | None = Field(
         description="The document attribute columns to use.",
         default=graphrag_config_defaults.input.metadata,
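A sketch of the relevant slice of the model (pydantic v2 style, with the defaults inlined instead of pulled from graphrag_config_defaults):

```python
from pydantic import BaseModel, Field


class InputConfigSketch(BaseModel):
    # All three column overrides are optional and default to None.
    id_column: str | None = Field(
        description="The input ID column to use.", default=None
    )
    title_column: str | None = Field(
        description="The input title column to use.", default=None
    )
    text_column: str | None = Field(
        description="The input text column to use.", default=None
    )
    metadata: list[str] | None = Field(
        description="The document attribute columns to use.", default=None
    )


cfg = InputConfigSketch(text_column="body", title_column="headline")
print(cfg.model_dump())  # id_column and metadata fall back to None
```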
@@ -27,16 +27,18 @@ class InputReader(metaclass=ABCMeta):
         file_type: str,
         encoding: str = "utf-8",
         file_pattern: str | None = None,
-        text_column: str | None = None,
+        id_column: str | None = None,
         title_column: str | None = None,
+        text_column: str = "text",
         metadata: list[str] | None = None,
         **kwargs,
     ):
         self._storage = storage
         self._file_type = file_type
         self._encoding = encoding
-        self._text_column = text_column
+        self._id_column = id_column
         self._title_column = title_column
+        self._text_column = text_column
         self._metadata = metadata

         # built-in readers set a default pattern if none is provided
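The constructor now defaults text_column to "text" while id and title stay optional. A stripped-down sketch of the pattern, with storage and file handling omitted:

```python
class ReaderSketch:
    # Stand-in for InputReader.__init__: id/title columns are optional,
    # text_column falls back to "text" when not configured.
    def __init__(
        self,
        id_column: str | None = None,
        title_column: str | None = None,
        text_column: str = "text",
        metadata: list[str] | None = None,
    ):
        self._id_column = id_column
        self._title_column = title_column
        self._text_column = text_column
        self._metadata = metadata


reader = ReaderSketch(title_column="headline")
print(reader._text_column)  # "text" - the built-in default
```

One subtlety: since the config declares text_column as `str | None` with a None default, a caller that forwards an unset config value explicitly would override the "text" default here, so callers presumably need to drop unset values before forwarding.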
@@ -52,7 +54,6 @@ class InputReader(metaclass=ABCMeta):
     async def read_files(self) -> pd.DataFrame:
         """Load files from storage and apply a loader function based on file type. Process metadata on the results if needed."""
         files = list(self._storage.find(re.compile(self._file_pattern)))
-
         if len(files) == 0:
             msg = f"No {self._file_type} files found in storage"  # TODO: use a storage __str__ to define it per impl
             logger.warning(msg)
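read_files compiles self._file_pattern as a regular expression and matches it against the storage listing. A small sketch of that matching step over an in-memory file list; the pattern and names are hypothetical:

```python
import re

# Hypothetical listing a storage backend might return.
files = ["docs/a.csv", "docs/b.csv", "notes/readme.md"]

file_pattern = re.compile(r".*\.csv$")
matched = [f for f in files if file_pattern.search(f)]

if len(matched) == 0:
    print("No csv files found in storage")
else:
    print(matched)  # ['docs/a.csv', 'docs/b.csv']
```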
@@ -32,7 +32,9 @@ class JSONFileReader(InputReader):
         # json file could just be a single object, or an array of objects
         rows = as_json if isinstance(as_json, list) else [as_json]
         data = pd.DataFrame(rows)
-        data = process_data_columns(data, path, self._text_column, self._title_column)
+        data = process_data_columns(
+            data, path, self._id_column, self._title_column, self._text_column
+        )
         creation_date = await self._storage.get_creation_date(path)
         data["creation_date"] = data.apply(lambda _: creation_date, axis=1)
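The single-object-or-array normalization the comment describes, as a standalone sketch:

```python
import json

import pandas as pd

for text in (
    '{"title": "One", "text": "A"}',
    '[{"title": "One", "text": "A"}, {"title": "Two", "text": "B"}]',
):
    as_json = json.loads(text)
    # A bare object becomes a one-row frame; a list becomes one row per item.
    rows = as_json if isinstance(as_json, list) else [as_json]
    data = pd.DataFrame(rows)
    print(len(data), "row(s)")
```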
@@ -15,32 +15,26 @@ logger = logging.getLogger(__name__)
 def process_data_columns(
     documents: pd.DataFrame,
     path: str,
-    text_column: str | None,
-    title_column: str | None,
+    id_column: str | None = None,
+    title_column: str | None = None,
+    text_column: str = "text",
 ) -> pd.DataFrame:
     """Process configured data columns of a DataFrame."""
     if "id" not in documents.columns:
+        # id is optional - harvest from df or hash from text
+        if id_column is not None:
+            documents["id"] = documents.apply(lambda x: x[id_column], axis=1)
+        else:
+            documents["id"] = documents.apply(
+                lambda x: gen_sha512_hash(x, x.keys()), axis=1
+            )
-    if text_column is not None and "text" not in documents.columns:
-        if text_column not in documents.columns:
-            logger.warning(
-                "text_column %s not found in csv file %s",
-                text_column,
-                path,
-            )
-        else:
-            documents["text"] = documents.apply(lambda x: x[text_column], axis=1)

+    # title is optional - harvest from df or use filename
     if title_column is not None:
         if title_column not in documents.columns:
             logger.warning(
                 "title_column %s not found in csv file %s",
                 title_column,
                 path,
             )
         else:
             documents["title"] = documents.apply(lambda x: x[title_column], axis=1)
     else:
         documents["title"] = documents.apply(lambda _: path, axis=1)

+    # text column is required - harvest from df
+    documents["text"] = documents.apply(lambda x: x[text_column], axis=1)

     return documents
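Taken together, the new behavior is roughly the following — a self-contained sketch in which gen_sha512_hash is stubbed with hashlib and the missing-column warning branch is reduced to a comment:

```python
import hashlib

import pandas as pd


def process_columns_sketch(
    documents: pd.DataFrame,
    path: str,
    id_column: str | None = None,
    title_column: str | None = None,
    text_column: str = "text",
) -> pd.DataFrame:
    # id is optional: harvest from the configured column, else hash the row
    # (stand-in for gen_sha512_hash over the row's keys).
    if "id" not in documents.columns:
        if id_column is not None:
            documents["id"] = documents.apply(lambda x: x[id_column], axis=1)
        else:
            documents["id"] = documents.apply(
                lambda x: hashlib.sha512(str(dict(x)).encode()).hexdigest(), axis=1
            )
    # title is optional: harvest from the configured column, else use the filename.
    if title_column is not None:
        if title_column in documents.columns:
            documents["title"] = documents.apply(lambda x: x[title_column], axis=1)
        # (the real code logs a warning when the configured column is missing)
    else:
        documents["title"] = documents.apply(lambda _: path, axis=1)
    # text is required: always harvested, defaulting to the "text" column.
    documents["text"] = documents.apply(lambda x: x[text_column], axis=1)
    return documents


df = pd.DataFrame({"headline": ["One", "Two"], "text": ["A", "B"]})
out = process_columns_sketch(df, "docs/a.csv", title_column="headline")
print(out[["id", "title", "text"]])
```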