mirror of
https://github.com/microsoft/graphrag.git
synced 2026-01-14 00:57:23 +08:00
Remove embedding column from df loaders
This commit is contained in:
parent
a8c1772340
commit
f4a20cd73d
@ -11,14 +11,12 @@ from typing import cast
|
||||
|
||||
import pandas as pd
|
||||
|
||||
from graphrag.config.models.graph_rag_config import GraphRagConfig
|
||||
from graphrag.data_model.community import Community
|
||||
from graphrag.data_model.community_report import CommunityReport
|
||||
from graphrag.data_model.covariate import Covariate
|
||||
from graphrag.data_model.entity import Entity
|
||||
from graphrag.data_model.relationship import Relationship
|
||||
from graphrag.data_model.text_unit import TextUnit
|
||||
from graphrag.language_model.manager import ModelManager
|
||||
from graphrag.language_model.protocol.base import EmbeddingModel
|
||||
from graphrag.query.input.loaders.dfs import (
|
||||
read_communities,
|
||||
@ -76,8 +74,6 @@ def read_indexer_reports(
|
||||
final_communities: pd.DataFrame,
|
||||
community_level: int | None,
|
||||
dynamic_community_selection: bool = False,
|
||||
content_embedding_col: str = "full_content_embedding",
|
||||
config: GraphRagConfig | None = None,
|
||||
) -> list[CommunityReport]:
|
||||
"""Read in the Community Reports from the raw indexing outputs.
|
||||
|
||||
@ -102,29 +98,7 @@ def read_indexer_reports(
|
||||
filtered_community_df, on="community", how="inner"
|
||||
)
|
||||
|
||||
if config and (
|
||||
content_embedding_col not in reports_df.columns
|
||||
or reports_df.loc[:, content_embedding_col].isna().any()
|
||||
):
|
||||
# TODO: Find a way to retrieve the right embedding model id.
|
||||
embedding_model_settings = config.get_language_model_config(
|
||||
"default_embedding_model"
|
||||
)
|
||||
embedder = ModelManager().get_or_create_embedding_model(
|
||||
name="default_embedding",
|
||||
model_type=embedding_model_settings.type,
|
||||
config=embedding_model_settings,
|
||||
)
|
||||
reports_df = embed_community_reports(
|
||||
reports_df, embedder, embedding_col=content_embedding_col
|
||||
)
|
||||
|
||||
return read_community_reports(
|
||||
df=reports_df,
|
||||
id_col="id",
|
||||
short_id_col="community",
|
||||
content_embedding_col=content_embedding_col,
|
||||
)
|
||||
return read_community_reports(df=reports_df, id_col="id", short_id_col="community")
|
||||
|
||||
|
||||
def read_indexer_report_embeddings(
|
||||
|
||||
@ -197,7 +197,6 @@ def read_community_reports(
|
||||
summary_col: str = "summary",
|
||||
content_col: str = "full_content",
|
||||
rank_col: str | None = "rank",
|
||||
content_embedding_col: str | None = "full_content_embedding",
|
||||
attributes_cols: list[str] | None = None,
|
||||
) -> list[CommunityReport]:
|
||||
"""Read community reports from a dataframe using pre-converted records."""
|
||||
@ -213,9 +212,6 @@ def read_community_reports(
|
||||
summary=to_str(row, summary_col),
|
||||
full_content=to_str(row, content_col),
|
||||
rank=to_optional_float(row, rank_col),
|
||||
full_content_embedding=to_optional_list(
|
||||
row, content_embedding_col, item_type=float
|
||||
),
|
||||
attributes=(
|
||||
{col: row.get(col) for col in attributes_cols}
|
||||
if attributes_cols
|
||||
|
||||
Loading…
Reference in New Issue
Block a user