Mirror of https://github.com/microsoft/graphrag.git, synced 2026-01-13 16:47:20 +08:00

Commit 14b1eccbff ("partially resolved merge conflicts"), parent 9ca67643b4.
.semversioner/1.0.1.json (new file, 22 lines)
@@ -0,0 +1,22 @@
{
    "changes": [
        {
            "description": "Fix encoding model config parsing",
            "type": "patch"
        },
        {
            "description": "Fix exception on error callbacks",
            "type": "patch"
        },
        {
            "description": "Manage llm instances inside a cached singleton. Check for empty dfs after entity/relationship extraction",
            "type": "patch"
        },
        {
            "description": "Respect encoding_model option",
            "type": "patch"
        }
    ],
    "created_at": "2024-12-18T23:12:52+00:00",
    "version": "1.0.1"
}
@@ -0,0 +1,4 @@
{
    "type": "patch",
    "description": "Implement cosmosdb storage option for cache and output"
}
@@ -0,0 +1,4 @@
{
    "type": "patch",
    "description": "Streamline flows."
}
@@ -0,0 +1,4 @@
{
    "type": "patch",
    "description": "Move extractor code to co-locate with operations."
}
graphrag/index/llm/manager.py (new file, 40 lines)
@@ -0,0 +1,40 @@
# Copyright (c) 2024 Microsoft Corporation.
# Licensed under the MIT License

"""LLM Manager singleton."""

from functools import cache

from fnllm import ChatLLM, EmbeddingsLLM


@cache
class ChatLLMSingleton:
    """A singleton class for the chat LLM instances."""

    def __init__(self):
        self.llm_dict = {}

    def set_llm(self, name, llm):
        """Add an LLM to the dictionary."""
        self.llm_dict[name] = llm

    def get_llm(self, name) -> ChatLLM | None:
        """Get an LLM from the dictionary."""
        return self.llm_dict.get(name)


@cache
class EmbeddingsLLMSingleton:
    """A singleton class for the embeddings LLM instances."""

    def __init__(self):
        self.llm_dict = {}

    def set_llm(self, name, llm):
        """Add an LLM to the dictionary."""
        self.llm_dict[name] = llm

    def get_llm(self, name) -> EmbeddingsLLM | None:
        """Get an LLM from the dictionary."""
        return self.llm_dict.get(name)
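Because functools.cache memoizes the class call itself, every no-argument construction returns the same instance. A minimal sketch of the resulting behavior; my_chat_llm stands in for any fnllm ChatLLM instance:

from graphrag.index.llm.manager import ChatLLMSingleton

a = ChatLLMSingleton()
b = ChatLLMSingleton()
assert a is b  # @cache memoizes the no-arg call, so both names share one registry

a.set_llm("default", my_chat_llm)  # my_chat_llm: a hypothetical ChatLLM instance
assert b.get_llm("default") is my_chat_llm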
graphrag/index/operations/compute_degree.py (new file, 15 lines)
@@ -0,0 +1,15 @@
# Copyright (c) 2024 Microsoft Corporation.
# Licensed under the MIT License

"""A module containing the compute_degree definition."""

import networkx as nx
import pandas as pd


def compute_degree(graph: nx.Graph) -> pd.DataFrame:
    """Create a new DataFrame with the degree of each node in the graph."""
    return pd.DataFrame([
        {"title": node, "degree": int(degree)}
        for node, degree in graph.degree  # type: ignore
    ])
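A quick usage sketch; the returned frame has one row per node with its integer degree:

import networkx as nx
g = nx.Graph()
g.add_edges_from([("A", "B"), ("A", "C")])
print(compute_degree(g))
#   title  degree
# 0     A       2
# 1     B       1
# 2     C       1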
graphrag/index/operations/embed_graph/embed_node2vec.py (new file, 43 lines)
@@ -0,0 +1,43 @@
# Copyright (c) 2024 Microsoft Corporation.
# Licensed under the MIT License

"""Utilities to generate graph embeddings."""

from dataclasses import dataclass

import networkx as nx
import numpy as np


@dataclass
class NodeEmbeddings:
    """Node embeddings class definition."""

    nodes: list[str]
    embeddings: np.ndarray


def embed_node2vec(
    graph: nx.Graph | nx.DiGraph,
    dimensions: int = 1536,
    num_walks: int = 10,
    walk_length: int = 40,
    window_size: int = 2,
    iterations: int = 3,
    random_seed: int = 86,
) -> NodeEmbeddings:
    """Generate node embeddings using Node2Vec."""
    # NOTE: This import is done here to reduce the initial import time of the graphrag package
    import graspologic as gc

    # generate embedding
    lcc_tensors = gc.embed.node2vec_embed(  # type: ignore
        graph=graph,
        dimensions=dimensions,
        window_size=window_size,
        iterations=iterations,
        num_walks=num_walks,
        walk_length=walk_length,
        random_seed=random_seed,
    )
    return NodeEmbeddings(embeddings=lcc_tensors[0], nodes=lcc_tensors[1])
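A usage sketch, assuming graspologic is installed (it is the deferred import above):

import networkx as nx
result = embed_node2vec(nx.karate_club_graph(), dimensions=32, num_walks=5)
print(len(result.nodes), result.embeddings.shape)  # 34 nodes and a (34, 32) array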
graphrag/index/operations/extract_covariates/claim_extractor.py (new file, 242 lines)
@@ -0,0 +1,242 @@
# Copyright (c) 2024 Microsoft Corporation.
# Licensed under the MIT License

"""A module containing 'ClaimExtractorResult' and 'ClaimExtractor' models."""

import logging
import traceback
from dataclasses import dataclass
from typing import Any

import tiktoken
from fnllm import ChatLLM

import graphrag.config.defaults as defs
from graphrag.index.typing import ErrorHandlerFn
from graphrag.prompts.index.claim_extraction import (
    CLAIM_EXTRACTION_PROMPT,
    CONTINUE_PROMPT,
    LOOP_PROMPT,
)

DEFAULT_TUPLE_DELIMITER = "<|>"
DEFAULT_RECORD_DELIMITER = "##"
DEFAULT_COMPLETION_DELIMITER = "<|COMPLETE|>"
log = logging.getLogger(__name__)


@dataclass
class ClaimExtractorResult:
    """Claim extractor result class definition."""

    output: list[dict]
    source_docs: dict[str, Any]


class ClaimExtractor:
    """Claim extractor class definition."""

    _llm: ChatLLM
    _extraction_prompt: str
    _summary_prompt: str
    _output_formatter_prompt: str
    _input_text_key: str
    _input_entity_spec_key: str
    _input_claim_description_key: str
    _tuple_delimiter_key: str
    _record_delimiter_key: str
    _completion_delimiter_key: str
    _max_gleanings: int
    _on_error: ErrorHandlerFn
    _loop_args: dict[str, Any]

    def __init__(
        self,
        llm_invoker: ChatLLM,
        extraction_prompt: str | None = None,
        input_text_key: str | None = None,
        input_entity_spec_key: str | None = None,
        input_claim_description_key: str | None = None,
        input_resolved_entities_key: str | None = None,
        tuple_delimiter_key: str | None = None,
        record_delimiter_key: str | None = None,
        completion_delimiter_key: str | None = None,
        encoding_model: str | None = None,
        max_gleanings: int | None = None,
        on_error: ErrorHandlerFn | None = None,
    ):
        """Init method definition."""
        self._llm = llm_invoker
        self._extraction_prompt = extraction_prompt or CLAIM_EXTRACTION_PROMPT
        self._input_text_key = input_text_key or "input_text"
        self._input_entity_spec_key = input_entity_spec_key or "entity_specs"
        self._tuple_delimiter_key = tuple_delimiter_key or "tuple_delimiter"
        self._record_delimiter_key = record_delimiter_key or "record_delimiter"
        self._completion_delimiter_key = (
            completion_delimiter_key or "completion_delimiter"
        )
        self._input_claim_description_key = (
            input_claim_description_key or "claim_description"
        )
        self._input_resolved_entities_key = (
            input_resolved_entities_key or "resolved_entities"
        )
        self._max_gleanings = (
            max_gleanings if max_gleanings is not None else defs.CLAIM_MAX_GLEANINGS
        )
        self._on_error = on_error or (lambda _e, _s, _d: None)

        # Construct the looping arguments
        encoding = tiktoken.get_encoding(encoding_model or "cl100k_base")
        yes = f"{encoding.encode('YES')[0]}"
        no = f"{encoding.encode('NO')[0]}"
        self._loop_args = {"logit_bias": {yes: 100, no: 100}, "max_tokens": 1}

    async def __call__(
        self, inputs: dict[str, Any], prompt_variables: dict | None = None
    ) -> ClaimExtractorResult:
        """Call method definition."""
        if prompt_variables is None:
            prompt_variables = {}
        texts = inputs[self._input_text_key]
        entity_spec = str(inputs[self._input_entity_spec_key])
        claim_description = inputs[self._input_claim_description_key]
        resolved_entities = inputs.get(self._input_resolved_entities_key, {})
        source_doc_map = {}

        prompt_args = {
            self._input_entity_spec_key: entity_spec,
            self._input_claim_description_key: claim_description,
            self._tuple_delimiter_key: prompt_variables.get(self._tuple_delimiter_key)
            or DEFAULT_TUPLE_DELIMITER,
            self._record_delimiter_key: prompt_variables.get(self._record_delimiter_key)
            or DEFAULT_RECORD_DELIMITER,
            self._completion_delimiter_key: prompt_variables.get(
                self._completion_delimiter_key
            )
            or DEFAULT_COMPLETION_DELIMITER,
        }

        all_claims: list[dict] = []
        for doc_index, text in enumerate(texts):
            document_id = f"d{doc_index}"
            try:
                claims = await self._process_document(prompt_args, text, doc_index)
                all_claims += [
                    self._clean_claim(c, document_id, resolved_entities) for c in claims
                ]
                source_doc_map[document_id] = text
            except Exception as e:
                log.exception("error extracting claim")
                self._on_error(
                    e,
                    traceback.format_exc(),
                    {"doc_index": doc_index, "text": text},
                )
                continue

        return ClaimExtractorResult(
            output=all_claims,
            source_docs=source_doc_map,
        )

    def _clean_claim(
        self, claim: dict, document_id: str, resolved_entities: dict
    ) -> dict:
        # normalize the subject/object identifiers on each parsed claim
        obj = claim.get("object_id", claim.get("object"))
        subject = claim.get("subject_id", claim.get("subject"))

        # If subject or object is in resolved entities, then replace with the resolved entity
        obj = resolved_entities.get(obj, obj)
        subject = resolved_entities.get(subject, subject)
        claim["object_id"] = obj
        claim["subject_id"] = subject
        return claim

    async def _process_document(
        self, prompt_args: dict, doc, doc_index: int
    ) -> list[dict]:
        record_delimiter = prompt_args.get(
            self._record_delimiter_key, DEFAULT_RECORD_DELIMITER
        )
        completion_delimiter = prompt_args.get(
            self._completion_delimiter_key, DEFAULT_COMPLETION_DELIMITER
        )

        response = await self._llm(
            self._extraction_prompt.format(**{
                self._input_text_key: doc,
                **prompt_args,
            })
        )
        results = response.output.content or ""
        claims = results.strip().removesuffix(completion_delimiter)

        # Repeat to ensure we maximize entity count
        for i in range(self._max_gleanings):
            response = await self._llm(
                CONTINUE_PROMPT,
                name=f"extract-continuation-{i}",
                history=response.history,
            )
            extension = response.output.content or ""
            claims += record_delimiter + extension.strip().removesuffix(
                completion_delimiter
            )

            # If this isn't the last loop, check to see if we should continue
            if i >= self._max_gleanings - 1:
                break

            response = await self._llm(
                LOOP_PROMPT,
                name=f"extract-loopcheck-{i}",
                history=response.history,
                model_parameters=self._loop_args,
            )
            if response.output.content != "YES":
                break

        # Parse the accumulated claims (including gleanings), not just the first response
        return self._parse_claim_tuples(claims, prompt_args)

    def _parse_claim_tuples(
        self, claims: str, prompt_variables: dict
    ) -> list[dict[str, Any]]:
        """Parse claim tuples."""
        record_delimiter = prompt_variables.get(
            self._record_delimiter_key, DEFAULT_RECORD_DELIMITER
        )
        completion_delimiter = prompt_variables.get(
            self._completion_delimiter_key, DEFAULT_COMPLETION_DELIMITER
        )
        tuple_delimiter = prompt_variables.get(
            self._tuple_delimiter_key, DEFAULT_TUPLE_DELIMITER
        )

        def pull_field(index: int, fields: list[str]) -> str | None:
            return fields[index].strip() if len(fields) > index else None

        result: list[dict[str, Any]] = []
        claims_values = (
            claims.strip().removesuffix(completion_delimiter).split(record_delimiter)
        )
        for claim in claims_values:
            claim = claim.strip().removeprefix("(").removesuffix(")")

            # Ignore the completion delimiter
            if claim == completion_delimiter:
                continue

            claim_fields = claim.split(tuple_delimiter)
            result.append({
                "subject_id": pull_field(0, claim_fields),
                "object_id": pull_field(1, claim_fields),
                "type": pull_field(2, claim_fields),
                "status": pull_field(3, claim_fields),
                "start_date": pull_field(4, claim_fields),
                "end_date": pull_field(5, claim_fields),
                "description": pull_field(6, claim_fields),
                "source_text": pull_field(7, claim_fields),
            })
        return result
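The _loop_args construction above relies on "YES" and "NO" each encoding to a single token; biasing both to +100 with max_tokens=1 forces the gleaning loop-check into a one-token YES/NO verdict. A quick check of that assumption:

import tiktoken
enc = tiktoken.get_encoding("cl100k_base")
print(enc.encode("YES"), enc.encode("NO"))  # each should be a single-element token list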
graphrag/index/operations/extract_entities/graph_extractor.py (new file, 308 lines)
@@ -0,0 +1,308 @@
# Copyright (c) 2024 Microsoft Corporation.
# Licensed under the MIT License

"""A module containing 'GraphExtractionResult' and 'GraphExtractor' models."""

import logging
import re
import traceback
from collections.abc import Mapping
from dataclasses import dataclass
from typing import Any

import networkx as nx
import tiktoken
from fnllm import ChatLLM

import graphrag.config.defaults as defs
from graphrag.index.typing import ErrorHandlerFn
from graphrag.index.utils.string import clean_str
from graphrag.prompts.index.entity_extraction import (
    CONTINUE_PROMPT,
    GRAPH_EXTRACTION_PROMPT,
    LOOP_PROMPT,
)

DEFAULT_TUPLE_DELIMITER = "<|>"
DEFAULT_RECORD_DELIMITER = "##"
DEFAULT_COMPLETION_DELIMITER = "<|COMPLETE|>"
DEFAULT_ENTITY_TYPES = ["organization", "person", "geo", "event"]

log = logging.getLogger(__name__)


@dataclass
class GraphExtractionResult:
    """Unipartite graph extraction result class definition."""

    output: nx.Graph
    source_docs: dict[Any, Any]


class GraphExtractor:
    """Unipartite graph extractor class definition."""

    _llm: ChatLLM
    _join_descriptions: bool
    _tuple_delimiter_key: str
    _record_delimiter_key: str
    _entity_types_key: str
    _input_text_key: str
    _completion_delimiter_key: str
    _entity_name_key: str
    _input_descriptions_key: str
    _extraction_prompt: str
    _summarization_prompt: str
    _loop_args: dict[str, Any]
    _max_gleanings: int
    _on_error: ErrorHandlerFn

    def __init__(
        self,
        llm_invoker: ChatLLM,
        tuple_delimiter_key: str | None = None,
        record_delimiter_key: str | None = None,
        input_text_key: str | None = None,
        entity_types_key: str | None = None,
        completion_delimiter_key: str | None = None,
        prompt: str | None = None,
        join_descriptions=True,
        encoding_model: str | None = None,
        max_gleanings: int | None = None,
        on_error: ErrorHandlerFn | None = None,
    ):
        """Init method definition."""
        # TODO: streamline construction
        self._llm = llm_invoker
        self._join_descriptions = join_descriptions
        self._input_text_key = input_text_key or "input_text"
        self._tuple_delimiter_key = tuple_delimiter_key or "tuple_delimiter"
        self._record_delimiter_key = record_delimiter_key or "record_delimiter"
        self._completion_delimiter_key = (
            completion_delimiter_key or "completion_delimiter"
        )
        self._entity_types_key = entity_types_key or "entity_types"
        self._extraction_prompt = prompt or GRAPH_EXTRACTION_PROMPT
        self._max_gleanings = (
            max_gleanings
            if max_gleanings is not None
            else defs.ENTITY_EXTRACTION_MAX_GLEANINGS
        )
        self._on_error = on_error or (lambda _e, _s, _d: None)

        # Construct the looping arguments
        encoding = tiktoken.get_encoding(encoding_model or "cl100k_base")
        yes = f"{encoding.encode('YES')[0]}"
        no = f"{encoding.encode('NO')[0]}"
        self._loop_args = {"logit_bias": {yes: 100, no: 100}, "max_tokens": 1}

    async def __call__(
        self, texts: list[str], prompt_variables: dict[str, Any] | None = None
    ) -> GraphExtractionResult:
        """Call method definition."""
        if prompt_variables is None:
            prompt_variables = {}
        all_records: dict[int, str] = {}
        source_doc_map: dict[int, str] = {}

        # Wire defaults into the prompt variables
        prompt_variables = {
            **prompt_variables,
            self._tuple_delimiter_key: prompt_variables.get(self._tuple_delimiter_key)
            or DEFAULT_TUPLE_DELIMITER,
            self._record_delimiter_key: prompt_variables.get(self._record_delimiter_key)
            or DEFAULT_RECORD_DELIMITER,
            self._completion_delimiter_key: prompt_variables.get(
                self._completion_delimiter_key
            )
            or DEFAULT_COMPLETION_DELIMITER,
            self._entity_types_key: ",".join(
                prompt_variables[self._entity_types_key] or DEFAULT_ENTITY_TYPES
            ),
        }

        for doc_index, text in enumerate(texts):
            try:
                # Invoke the entity extraction
                result = await self._process_document(text, prompt_variables)
                source_doc_map[doc_index] = text
                all_records[doc_index] = result
            except Exception as e:
                log.exception("error extracting graph")
                self._on_error(
                    e,
                    traceback.format_exc(),
                    {
                        "doc_index": doc_index,
                        "text": text,
                    },
                )

        output = await self._process_results(
            all_records,
            prompt_variables.get(self._tuple_delimiter_key, DEFAULT_TUPLE_DELIMITER),
            prompt_variables.get(self._record_delimiter_key, DEFAULT_RECORD_DELIMITER),
        )

        return GraphExtractionResult(
            output=output,
            source_docs=source_doc_map,
        )

    async def _process_document(
        self, text: str, prompt_variables: dict[str, str]
    ) -> str:
        response = await self._llm(
            self._extraction_prompt.format(**{
                **prompt_variables,
                self._input_text_key: text,
            }),
        )
        results = response.output.content or ""

        # Repeat to ensure we maximize entity count
        for i in range(self._max_gleanings):
            response = await self._llm(
                CONTINUE_PROMPT,
                name=f"extract-continuation-{i}",
                history=response.history,
            )
            results += response.output.content or ""

            # if this is the final glean, don't bother updating the continuation flag
            if i >= self._max_gleanings - 1:
                break

            response = await self._llm(
                LOOP_PROMPT,
                name=f"extract-loopcheck-{i}",
                history=response.history,
                model_parameters=self._loop_args,
            )
            # compare the message content (not the response wrapper) to the YES token
            if response.output.content != "YES":
                break

        return results

    async def _process_results(
        self,
        results: dict[int, str],
        tuple_delimiter: str,
        record_delimiter: str,
    ) -> nx.Graph:
        """Parse the result string to create an undirected unipartite graph.

        Args:
            - results - dict of results from the extraction chain
            - tuple_delimiter - delimiter between tuples in an output record, default is '<|>'
            - record_delimiter - delimiter between records, default is '##'
        Returns:
            - output - unipartite graph in graphML format
        """
        graph = nx.Graph()
        for source_doc_id, extracted_data in results.items():
            records = [r.strip() for r in extracted_data.split(record_delimiter)]

            for record in records:
                record = re.sub(r"^\(|\)$", "", record.strip())
                record_attributes = record.split(tuple_delimiter)

                if record_attributes[0] == '"entity"' and len(record_attributes) >= 4:
                    # add this record as a node in the G
                    entity_name = clean_str(record_attributes[1].upper())
                    entity_type = clean_str(record_attributes[2].upper())
                    entity_description = clean_str(record_attributes[3])

                    if entity_name in graph.nodes():
                        node = graph.nodes[entity_name]
                        if self._join_descriptions:
                            node["description"] = "\n".join(
                                list({
                                    *_unpack_descriptions(node),
                                    entity_description,
                                })
                            )
                        else:
                            if len(entity_description) > len(node["description"]):
                                node["description"] = entity_description
                        node["source_id"] = ", ".join(
                            list({
                                *_unpack_source_ids(node),
                                str(source_doc_id),
                            })
                        )
                        node["type"] = (
                            entity_type if entity_type != "" else node["type"]
                        )
                    else:
                        graph.add_node(
                            entity_name,
                            type=entity_type,
                            description=entity_description,
                            source_id=str(source_doc_id),
                        )

                if (
                    record_attributes[0] == '"relationship"'
                    and len(record_attributes) >= 5
                ):
                    # add this record as edge
                    source = clean_str(record_attributes[1].upper())
                    target = clean_str(record_attributes[2].upper())
                    edge_description = clean_str(record_attributes[3])
                    edge_source_id = clean_str(str(source_doc_id))
                    try:
                        weight = float(record_attributes[-1])
                    except ValueError:
                        weight = 1.0

                    if source not in graph.nodes():
                        graph.add_node(
                            source,
                            type="",
                            description="",
                            source_id=edge_source_id,
                        )
                    if target not in graph.nodes():
                        graph.add_node(
                            target,
                            type="",
                            description="",
                            source_id=edge_source_id,
                        )
                    if graph.has_edge(source, target):
                        edge_data = graph.get_edge_data(source, target)
                        if edge_data is not None:
                            weight += edge_data["weight"]
                            if self._join_descriptions:
                                edge_description = "\n".join(
                                    list({
                                        *_unpack_descriptions(edge_data),
                                        edge_description,
                                    })
                                )
                            edge_source_id = ", ".join(
                                list({
                                    *_unpack_source_ids(edge_data),
                                    str(source_doc_id),
                                })
                            )
                    graph.add_edge(
                        source,
                        target,
                        weight=weight,
                        description=edge_description,
                        source_id=edge_source_id,
                    )

        return graph


def _unpack_descriptions(data: Mapping) -> list[str]:
    value = data.get("description", None)
    return [] if value is None else value.split("\n")


def _unpack_source_ids(data: Mapping) -> list[str]:
    value = data.get("source_id", None)
    return [] if value is None else value.split(", ")
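A self-contained sketch of the record format _process_results parses, using the default delimiters defined above (the sample text is illustrative, not real model output):

import re

raw = (
    '("entity"<|>MICROSOFT<|>ORGANIZATION<|>A technology company)##'
    '("relationship"<|>MICROSOFT<|>SEATTLE<|>Headquartered nearby<|>2)'
    "<|COMPLETE|>"
)
for record in raw.removesuffix("<|COMPLETE|>").split("##"):
    fields = re.sub(r"^\(|\)$", "", record.strip()).split("<|>")
    print(fields[0], fields[1:])
# '"entity"' records carry (name, type, description)
# '"relationship"' records carry (source, target, description, weight)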
@@ -0,0 +1,139 @@
# Copyright (c) 2024 Microsoft Corporation.
# Licensed under the MIT License

"""A module containing run_graph_intelligence, run_extract_entities and _create_text_splitter methods to run graph intelligence."""

import networkx as nx
from datashaper import VerbCallbacks
from fnllm import ChatLLM

import graphrag.config.defaults as defs
from graphrag.cache.pipeline_cache import PipelineCache
from graphrag.index.llm.load_llm import load_llm, read_llm_params
from graphrag.index.operations.extract_entities.graph_extractor import GraphExtractor
from graphrag.index.operations.extract_entities.typing import (
    Document,
    EntityExtractionResult,
    EntityTypes,
    StrategyConfig,
)
from graphrag.index.text_splitting.text_splitting import (
    NoopTextSplitter,
    TextSplitter,
    TokenTextSplitter,
)


async def run_graph_intelligence(
    docs: list[Document],
    entity_types: EntityTypes,
    callbacks: VerbCallbacks,
    cache: PipelineCache,
    args: StrategyConfig,
) -> EntityExtractionResult:
    """Run the graph intelligence entity extraction strategy."""
    llm_config = read_llm_params(args.get("llm", {}))
    llm = load_llm("entity_extraction", llm_config, callbacks=callbacks, cache=cache)
    return await run_extract_entities(llm, docs, entity_types, callbacks, args)


async def run_extract_entities(
    llm: ChatLLM,
    docs: list[Document],
    entity_types: EntityTypes,
    callbacks: VerbCallbacks | None,
    args: StrategyConfig,
) -> EntityExtractionResult:
    """Run the entity extraction chain."""
    encoding_name = args.get("encoding_name", "cl100k_base")

    # Chunking Arguments
    prechunked = args.get("prechunked", False)
    chunk_size = args.get("chunk_size", defs.CHUNK_SIZE)
    chunk_overlap = args.get("chunk_overlap", defs.CHUNK_OVERLAP)

    # Extraction Arguments
    tuple_delimiter = args.get("tuple_delimiter", None)
    record_delimiter = args.get("record_delimiter", None)
    completion_delimiter = args.get("completion_delimiter", None)
    extraction_prompt = args.get("extraction_prompt", None)
    encoding_model = args.get("encoding_name", None)
    max_gleanings = args.get("max_gleanings", defs.ENTITY_EXTRACTION_MAX_GLEANINGS)

    # note: We're not using UnipartiteGraphChain.from_params
    # because we want to pass "timeout" to the llm_kwargs
    text_splitter = _create_text_splitter(
        prechunked, chunk_size, chunk_overlap, encoding_name
    )

    extractor = GraphExtractor(
        llm_invoker=llm,
        prompt=extraction_prompt,
        encoding_model=encoding_model,
        max_gleanings=max_gleanings,
        on_error=lambda e, s, d: (
            callbacks.error("Entity Extraction Error", e, s, d) if callbacks else None
        ),
    )
    text_list = [doc.text.strip() for doc in docs]

    # If it's not pre-chunked, then re-chunk the input
    if not prechunked:
        text_list = text_splitter.split_text("\n".join(text_list))

    results = await extractor(
        list(text_list),
        {
            "entity_types": entity_types,
            "tuple_delimiter": tuple_delimiter,
            "record_delimiter": record_delimiter,
            "completion_delimiter": completion_delimiter,
        },
    )

    graph = results.output
    # Map the "source_id" back to the "id" field
    for _, node in graph.nodes(data=True):  # type: ignore
        if node is not None:
            node["source_id"] = ",".join(
                docs[int(id)].id for id in node["source_id"].split(",")
            )

    for _, _, edge in graph.edges(data=True):  # type: ignore
        if edge is not None:
            edge["source_id"] = ",".join(
                docs[int(id)].id for id in edge["source_id"].split(",")
            )

    entities = [
        ({"title": item[0], **(item[1] or {})})
        for item in graph.nodes(data=True)
        if item is not None
    ]

    relationships = nx.to_pandas_edgelist(graph)

    return EntityExtractionResult(entities, relationships, graph)


def _create_text_splitter(
    prechunked: bool, chunk_size: int, chunk_overlap: int, encoding_name: str
) -> TextSplitter:
    """Create a text splitter for the extraction chain.

    Args:
        - prechunked - Whether the text is already chunked
        - chunk_size - The size of each chunk
        - chunk_overlap - The overlap between chunks
        - encoding_name - The name of the encoding to use
    Returns:
        - output - A text splitter
    """
    if prechunked:
        return NoopTextSplitter()

    return TokenTextSplitter(
        chunk_size=chunk_size,
        chunk_overlap=chunk_overlap,
        encoding_name=encoding_name,
    )
graphrag/index/operations/extract_entities/nltk_strategy.py (new file, 66 lines)
@@ -0,0 +1,66 @@
# Copyright (c) 2024 Microsoft Corporation.
# Licensed under the MIT License

"""A module containing run method definition."""

import networkx as nx
import nltk
from datashaper import VerbCallbacks
from nltk.corpus import words

from graphrag.cache.pipeline_cache import PipelineCache
from graphrag.index.operations.extract_entities.typing import (
    Document,
    EntityExtractionResult,
    EntityTypes,
    StrategyConfig,
)

# Need to do this because we're potentially multithreading, and nltk doesn't like that
words.ensure_loaded()


async def run(  # noqa RUF029 async is required for interface
    docs: list[Document],
    entity_types: EntityTypes,
    callbacks: VerbCallbacks,  # noqa ARG001
    cache: PipelineCache,  # noqa ARG001
    args: StrategyConfig,  # noqa ARG001
) -> EntityExtractionResult:
    """Run method definition."""
    entity_map = {}
    graph = nx.Graph()
    for doc in docs:
        connected_entities = []
        for chunk in nltk.ne_chunk(nltk.pos_tag(nltk.word_tokenize(doc.text))):
            if hasattr(chunk, "label"):
                entity_type = chunk.label().lower()
                if entity_type in entity_types:
                    name = (" ".join(c[0] for c in chunk)).upper()
                    connected_entities.append(name)
                    if name not in entity_map:
                        entity_map[name] = entity_type
                        graph.add_node(
                            name, type=entity_type, description=name, source_id=doc.id
                        )

        # connect the entities if they appear in the same document
        if len(connected_entities) > 1:
            for i in range(len(connected_entities)):
                for j in range(i + 1, len(connected_entities)):
                    description = f"{connected_entities[i]} -> {connected_entities[j]}"
                    graph.add_edge(
                        connected_entities[i],
                        connected_entities[j],
                        description=description,
                        source_id=doc.id,
                    )

    return EntityExtractionResult(
        entities=[
            {"type": entity_type, "title": name}
            for name, entity_type in entity_map.items()
        ],
        relationships=[],
        graph=graph,
    )
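A minimal sketch, assuming the required NLTK data (punkt, averaged_perceptron_tagger, maxent_ne_chunker, words) is already downloaded. The callbacks, cache, and args parameters are unused by this strategy, so placeholders suffice at runtime:

import asyncio
from graphrag.index.operations.extract_entities.typing import Document

docs = [Document(text="Steve Jobs founded Apple in California.", id="doc-1")]
result = asyncio.run(run(docs, ["person", "organization", "gpe"], None, None, {}))
print([e["title"] for e in result.entities])  # tagger-dependent, e.g. ['STEVE JOBS', 'APPLE', 'CALIFORNIA']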
graphrag/index/operations/extract_entities/typing.py (new file, 59 lines)
@@ -0,0 +1,59 @@
# Copyright (c) 2024 Microsoft Corporation.
# Licensed under the MIT License

"""A module containing 'Document' and 'EntityExtractionResult' models."""

from collections.abc import Awaitable, Callable
from dataclasses import dataclass
from enum import Enum
from typing import Any

import networkx as nx
from datashaper import VerbCallbacks

from graphrag.cache.pipeline_cache import PipelineCache

ExtractedEntity = dict[str, Any]
ExtractedRelationship = dict[str, Any]
StrategyConfig = dict[str, Any]
EntityTypes = list[str]


@dataclass
class Document:
    """Document class definition."""

    text: str
    id: str


@dataclass
class EntityExtractionResult:
    """Entity extraction result class definition."""

    entities: list[ExtractedEntity]
    relationships: list[ExtractedRelationship]
    graph: nx.Graph | None


EntityExtractStrategy = Callable[
    [
        list[Document],
        EntityTypes,
        VerbCallbacks,
        PipelineCache,
        StrategyConfig,
    ],
    Awaitable[EntityExtractionResult],
]


class ExtractEntityStrategyType(str, Enum):
    """ExtractEntityStrategyType class definition."""

    graph_intelligence = "graph_intelligence"
    nltk = "nltk"

    def __repr__(self):
        """Get a string representation."""
        return f'"{self.value}"'
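Any coroutine matching the EntityExtractStrategy alias can serve as a strategy; a hypothetical no-op implementation for illustration:

async def noop_strategy(
    docs: list[Document],
    entity_types: EntityTypes,
    callbacks: VerbCallbacks,
    cache: PipelineCache,
    args: StrategyConfig,
) -> EntityExtractionResult:
    """A do-nothing strategy that satisfies the EntityExtractStrategy contract."""
    return EntityExtractionResult(entities=[], relationships=[], graph=None)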
graphrag/index/operations/layout_graph/typing.py (new file, 27 lines)
@@ -0,0 +1,27 @@
# Copyright (c) 2024 Microsoft Corporation.
# Licensed under the MIT License

# Use this for now instead of a wrapper
"""A module containing 'NodePosition' model."""

from dataclasses import dataclass


@dataclass
class NodePosition:
    """Node position class definition."""

    label: str
    cluster: str
    size: float

    x: float
    y: float
    z: float | None = None

    def to_pandas(self) -> tuple[str, float, float, str, float]:
        """To pandas method definition."""
        return self.label, self.x, self.y, self.cluster, self.size


GraphLayout = list[NodePosition]
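Note the column ordering to_pandas yields: (label, x, y, cluster, size).

pos = NodePosition(label="A", cluster="1", size=1.0, x=0.5, y=-0.2)
print(pos.to_pandas())  # ('A', 0.5, -0.2, '1', 1.0)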
graphrag/index/operations/layout_graph/umap.py (new file, 136 lines)
@@ -0,0 +1,136 @@
# Copyright (c) 2024 Microsoft Corporation.
# Licensed under the MIT License

"""A module containing run and _create_node_position methods definitions."""

import logging
import traceback
from typing import Any

import networkx as nx
import numpy as np

from graphrag.index.operations.embed_graph.typing import NodeEmbeddings
from graphrag.index.operations.layout_graph.typing import (
    GraphLayout,
    NodePosition,
)
from graphrag.index.typing import ErrorHandlerFn

# TODO: This could be handled more elegantly, like what columns to use
# for "size" or "cluster"
# We could also have a boolean to indicate to use node sizes or clusters

log = logging.getLogger(__name__)


def run(
    graph: nx.Graph,
    embeddings: NodeEmbeddings,
    args: dict[str, Any],
    on_error: ErrorHandlerFn,
) -> GraphLayout:
    """Run method definition."""
    node_clusters = []
    node_sizes = []

    embeddings = _filter_raw_embeddings(embeddings)
    nodes = list(embeddings.keys())
    embedding_vectors = [embeddings[node_id] for node_id in nodes]

    for node_id in nodes:
        node = graph.nodes[node_id]
        cluster = node.get("cluster", node.get("community", -1))
        node_clusters.append(cluster)
        size = node.get("degree", node.get("size", 0))
        node_sizes.append(size)

    additional_args = {}
    if len(node_clusters) > 0:
        additional_args["node_categories"] = node_clusters
    if len(node_sizes) > 0:
        additional_args["node_sizes"] = node_sizes

    try:
        return compute_umap_positions(
            embedding_vectors=np.array(embedding_vectors),
            node_labels=nodes,
            **additional_args,
            min_dist=args.get("min_dist", 0.75),
            n_neighbors=args.get("n_neighbors", 5),
        )
    except Exception as e:
        log.exception("Error running UMAP")
        on_error(e, traceback.format_exc(), None)
        # Umap may fail due to input sparseness or memory pressure.
        # For now, in these cases, we'll just return a layout with all nodes at (0, 0)
        result = []
        for i in range(len(nodes)):
            cluster = node_clusters[i] if len(node_clusters) > 0 else 1
            result.append(
                NodePosition(x=0, y=0, label=nodes[i], size=0, cluster=str(cluster))
            )
        return result


def _filter_raw_embeddings(embeddings: NodeEmbeddings) -> NodeEmbeddings:
    return {
        node_id: embedding
        for node_id, embedding in embeddings.items()
        if embedding is not None
    }


def compute_umap_positions(
    embedding_vectors: np.ndarray,
    node_labels: list[str],
    node_categories: list[int] | None = None,
    node_sizes: list[int] | None = None,
    min_dist: float = 0.75,
    n_neighbors: int = 25,
    spread: int = 1,
    metric: str = "euclidean",
    n_components: int = 2,
    random_state: int = 86,
) -> list[NodePosition]:
    """Project embedding vectors down to 2D/3D using UMAP."""
    # NOTE: This import is done here to reduce the initial import time of the graphrag package
    import umap

    embedding_positions = umap.UMAP(
        min_dist=min_dist,
        n_neighbors=n_neighbors,
        spread=spread,
        n_components=n_components,
        metric=metric,
        random_state=random_state,
    ).fit_transform(embedding_vectors)

    embedding_position_data: list[NodePosition] = []
    for index, node_name in enumerate(node_labels):
        node_points = embedding_positions[index]  # type: ignore
        node_category = 1 if node_categories is None else node_categories[index]
        node_size = 1 if node_sizes is None else node_sizes[index]

        if len(node_points) == 2:
            embedding_position_data.append(
                NodePosition(
                    label=str(node_name),
                    x=float(node_points[0]),
                    y=float(node_points[1]),
                    cluster=str(int(node_category)),
                    size=int(node_size),
                )
            )
        else:
            embedding_position_data.append(
                NodePosition(
                    label=str(node_name),
                    x=float(node_points[0]),
                    y=float(node_points[1]),
                    z=float(node_points[2]),
                    cluster=str(int(node_category)),
                    size=int(node_size),
                )
            )
    return embedding_position_data
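A usage sketch with random vectors, assuming umap-learn is installed (it is the deferred import above); n_neighbors must stay below the number of points:

import numpy as np
layout = compute_umap_positions(
    embedding_vectors=np.random.rand(20, 8),
    node_labels=[f"n{i}" for i in range(20)],
    n_neighbors=5,
)
print(layout[0])  # NodePosition(label='n0', cluster='1', size=1, x=..., y=..., z=None)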
graphrag/index/operations/layout_graph/zero.py (new file, 98 lines)
@@ -0,0 +1,98 @@
# Copyright (c) 2024 Microsoft Corporation.
# Licensed under the MIT License

"""A module containing run and _create_node_position methods definitions."""

import logging
import traceback
from typing import Any

import networkx as nx

from graphrag.index.operations.layout_graph.typing import (
    GraphLayout,
    NodePosition,
)
from graphrag.index.typing import ErrorHandlerFn

# TODO: This could be handled more elegantly, like what columns to use
# for "size" or "cluster"
# We could also have a boolean to indicate to use node sizes or clusters

log = logging.getLogger(__name__)


def run(
    graph: nx.Graph,
    _args: dict[str, Any],
    on_error: ErrorHandlerFn,
) -> GraphLayout:
    """Run method definition."""
    node_clusters = []
    node_sizes = []

    nodes = list(graph.nodes)

    for node_id in nodes:
        node = graph.nodes[node_id]
        cluster = node.get("cluster", node.get("community", -1))
        node_clusters.append(cluster)
        size = node.get("degree", node.get("size", 0))
        node_sizes.append(size)

    additional_args = {}
    if len(node_clusters) > 0:
        additional_args["node_categories"] = node_clusters
    if len(node_sizes) > 0:
        additional_args["node_sizes"] = node_sizes

    try:
        return get_zero_positions(node_labels=nodes, **additional_args)
    except Exception as e:
        log.exception("Error running zero-position")
        on_error(e, traceback.format_exc(), None)
        # The layout may fail; in that case, just return a layout with all nodes at (0, 0)
        result = []
        for i in range(len(nodes)):
            cluster = node_clusters[i] if len(node_clusters) > 0 else 1
            result.append(
                NodePosition(x=0, y=0, label=nodes[i], size=0, cluster=str(cluster))
            )
        return result


def get_zero_positions(
    node_labels: list[str],
    node_categories: list[int] | None = None,
    node_sizes: list[int] | None = None,
    three_d: bool | None = False,
) -> list[NodePosition]:
    """Place all nodes at the 2D/3D origin while preserving cluster and size metadata."""
    embedding_position_data: list[NodePosition] = []
    for index, node_name in enumerate(node_labels):
        node_category = 1 if node_categories is None else node_categories[index]
        node_size = 1 if node_sizes is None else node_sizes[index]

        if not three_d:
            embedding_position_data.append(
                NodePosition(
                    label=str(node_name),
                    x=0,
                    y=0,
                    cluster=str(int(node_category)),
                    size=int(node_size),
                )
            )
        else:
            embedding_position_data.append(
                NodePosition(
                    label=str(node_name),
                    x=0,
                    y=0,
                    z=0,
                    cluster=str(int(node_category)),
                    size=int(node_size),
                )
            )
    return embedding_position_data
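A quick sketch of the fallback layout this module produces:

layout = get_zero_positions(node_labels=["A", "B"], node_sizes=[3, 1])
print(layout[0])  # NodePosition(label='A', cluster='1', size=3, x=0, y=0, z=None)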
@@ -0,0 +1,26 @@
# Copyright (c) 2024 Microsoft Corporation.
# Licensed under the MIT License

"""The Indexing Engine community reports package root."""

import graphrag.index.operations.summarize_communities.community_reports_extractor.schemas as schemas
from graphrag.index.operations.summarize_communities.community_reports_extractor.build_mixed_context import (
    build_mixed_context,
)
from graphrag.index.operations.summarize_communities.community_reports_extractor.community_reports_extractor import (
    CommunityReportsExtractor,
)
from graphrag.index.operations.summarize_communities.community_reports_extractor.prep_community_report_context import (
    prep_community_report_context,
)
from graphrag.index.operations.summarize_communities.community_reports_extractor.sort_context import (
    sort_context,
)

__all__ = [
    "CommunityReportsExtractor",
    "build_mixed_context",
    "prep_community_report_context",
    "schemas",
    "sort_context",
]
@@ -0,0 +1,70 @@
# Copyright (c) 2024 Microsoft Corporation.
# Licensed under the MIT License
"""A module containing the build_mixed_context method definition."""

import pandas as pd

import graphrag.index.operations.summarize_communities.community_reports_extractor.schemas as schemas
from graphrag.index.operations.summarize_communities.community_reports_extractor.sort_context import (
    sort_context,
)
from graphrag.query.llm.text_utils import num_tokens


def build_mixed_context(context: list[dict], max_tokens: int) -> str:
    """
    Build parent context by concatenating all sub-communities' contexts.

    If the context exceeds the limit, we use sub-community reports instead.
    """
    sorted_context = sorted(
        context, key=lambda x: x[schemas.CONTEXT_SIZE], reverse=True
    )

    # replace local context with sub-community reports, starting from the biggest sub-community
    substitute_reports = []
    final_local_contexts = []
    exceeded_limit = True
    context_string = ""

    for idx, sub_community_context in enumerate(sorted_context):
        if exceeded_limit:
            if sub_community_context[schemas.FULL_CONTENT]:
                substitute_reports.append({
                    schemas.COMMUNITY_ID: sub_community_context[schemas.SUB_COMMUNITY],
                    schemas.FULL_CONTENT: sub_community_context[schemas.FULL_CONTENT],
                })
            else:
                # this sub-community has no report, so we will use its local context
                final_local_contexts.extend(sub_community_context[schemas.ALL_CONTEXT])
                continue

            # add local context for the remaining sub-communities
            remaining_local_context = []
            for rid in range(idx + 1, len(sorted_context)):
                remaining_local_context.extend(sorted_context[rid][schemas.ALL_CONTEXT])
            new_context_string = sort_context(
                local_context=remaining_local_context + final_local_contexts,
                sub_community_reports=substitute_reports,
            )
            if num_tokens(new_context_string) <= max_tokens:
                exceeded_limit = False
                context_string = new_context_string
                break

    if exceeded_limit:
        # if all sub-community reports exceed the limit, we add reports until context is full
        substitute_reports = []
        for sub_community_context in sorted_context:
            substitute_reports.append({
                schemas.COMMUNITY_ID: sub_community_context[schemas.SUB_COMMUNITY],
                schemas.FULL_CONTENT: sub_community_context[schemas.FULL_CONTENT],
            })
            new_context_string = pd.DataFrame(substitute_reports).to_csv(
                index=False, sep=","
            )
            if num_tokens(new_context_string) > max_tokens:
                break

            context_string = new_context_string
    return context_string
@@ -0,0 +1,106 @@
# Copyright (c) 2024 Microsoft Corporation.
# Licensed under the MIT License

"""A module containing 'CommunityReportsResult' and 'CommunityReportsExtractor' models."""

import logging
import traceback
from dataclasses import dataclass
from typing import Any

from fnllm import ChatLLM
from pydantic import BaseModel, Field

from graphrag.index.typing import ErrorHandlerFn
from graphrag.prompts.index.community_report import COMMUNITY_REPORT_PROMPT

log = logging.getLogger(__name__)


class FindingModel(BaseModel):
    """A model for the expected LLM response shape."""

    summary: str = Field(description="The summary of the finding.")
    explanation: str = Field(description="An explanation of the finding.")


class CommunityReportResponse(BaseModel):
    """A model for the expected LLM response shape."""

    title: str = Field(description="The title of the report.")
    summary: str = Field(description="A summary of the report.")
    findings: list[FindingModel] = Field(
        description="A list of findings in the report."
    )
    rating: float = Field(description="The rating of the report.")
    rating_explanation: str = Field(description="An explanation of the rating.")

    extra_attributes: dict[str, Any] = Field(
        default_factory=dict, description="Extra attributes."
    )


@dataclass
class CommunityReportsResult:
    """Community reports result class definition."""

    output: str
    structured_output: CommunityReportResponse | None


class CommunityReportsExtractor:
    """Community reports extractor class definition."""

    _llm: ChatLLM
    _input_text_key: str
    _extraction_prompt: str
    _output_formatter_prompt: str
    _on_error: ErrorHandlerFn
    _max_report_length: int

    def __init__(
        self,
        llm_invoker: ChatLLM,
        input_text_key: str | None = None,
        extraction_prompt: str | None = None,
        on_error: ErrorHandlerFn | None = None,
        max_report_length: int | None = None,
    ):
        """Init method definition."""
        self._llm = llm_invoker
        self._input_text_key = input_text_key or "input_text"
        self._extraction_prompt = extraction_prompt or COMMUNITY_REPORT_PROMPT
        self._on_error = on_error or (lambda _e, _s, _d: None)
        self._max_report_length = max_report_length or 1500

    async def __call__(self, inputs: dict[str, Any]):
        """Call method definition."""
        output = None
        try:
            input_text = inputs[self._input_text_key]
            prompt = self._extraction_prompt.replace(
                "{" + self._input_text_key + "}", input_text
            )
            response = await self._llm(
                prompt,
                json=True,
                name="create_community_report",
                json_model=CommunityReportResponse,
                model_parameters={"max_tokens": self._max_report_length},
            )
            output = response.parsed_json
        except Exception as e:
            log.exception("error generating community report")
            self._on_error(e, traceback.format_exc(), None)

        text_output = self._get_text_output(output) if output else ""
        return CommunityReportsResult(
            structured_output=output,
            output=text_output,
        )

    def _get_text_output(self, report: CommunityReportResponse) -> str:
        report_sections = "\n\n".join(
            f"## {f.summary}\n\n{f.explanation}" for f in report.findings
        )
        return f"# {report.title}\n\n{report.summary}\n\n{report_sections}"
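A sketch of the structured response shape and its markdown rendering (all values are illustrative):

report = CommunityReportResponse(
    title="Example Community",
    summary="A short overview.",
    findings=[FindingModel(summary="Key finding", explanation="Why it matters.")],
    rating=7.5,
    rating_explanation="High impact.",
)
# _get_text_output renders this as:
# "# Example Community\n\nA short overview.\n\n## Key finding\n\nWhy it matters."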
@@ -0,0 +1,185 @@
# Copyright (c) 2024 Microsoft Corporation.
# Licensed under the MIT License

"""A module containing create_community_reports and load_strategy methods definition."""

import logging
from typing import cast

import pandas as pd

import graphrag.index.operations.summarize_communities.community_reports_extractor.schemas as schemas
from graphrag.index.operations.summarize_communities.community_reports_extractor.build_mixed_context import (
    build_mixed_context,
)
from graphrag.index.operations.summarize_communities.community_reports_extractor.sort_context import (
    sort_context,
)
from graphrag.index.utils.dataframes import (
    antijoin,
    drop_columns,
    join,
    select,
    transform_series,
    union,
    where_column_equals,
)
from graphrag.query.llm.text_utils import num_tokens

log = logging.getLogger(__name__)


def prep_community_report_context(
    report_df: pd.DataFrame | None,
    community_hierarchy_df: pd.DataFrame,
    local_context_df: pd.DataFrame,
    level: int,
    max_tokens: int,
) -> pd.DataFrame:
    """
    Prep context for each community in a given level.

    For each community:
    - Check if local context fits within the limit, if yes use local context
    - If local context exceeds the limit, iteratively replace local context with sub-community reports, starting from the biggest sub-community
    """
    if report_df is None:
        report_df = pd.DataFrame()

    # Filter by community level
    level_context_df = local_context_df.loc[
        local_context_df.loc[:, schemas.COMMUNITY_LEVEL] == level
    ]

    # Filter valid and invalid contexts using boolean logic
    valid_context_df = level_context_df.loc[
        ~level_context_df.loc[:, schemas.CONTEXT_EXCEED_FLAG]
    ]
    invalid_context_df = level_context_df.loc[
        level_context_df.loc[:, schemas.CONTEXT_EXCEED_FLAG]
    ]

    # there is no report to substitute with, so we just trim the local context of the invalid context records
    # this case should only happen at the bottom level of the community hierarchy where there are no sub-communities
    if invalid_context_df.empty:
        return valid_context_df

    if report_df.empty:
        invalid_context_df.loc[:, schemas.CONTEXT_STRING] = _sort_and_trim_context(
            invalid_context_df, max_tokens
        )
        invalid_context_df[schemas.CONTEXT_SIZE] = invalid_context_df.loc[
            :, schemas.CONTEXT_STRING
        ].map(num_tokens)
        invalid_context_df[schemas.CONTEXT_EXCEED_FLAG] = 0
        return union(valid_context_df, invalid_context_df)

    level_context_df = _antijoin_reports(level_context_df, report_df)

    # for each invalid context, we will try to substitute with sub-community reports
    # first get local context and report (if available) for each sub-community
    sub_context_df = _get_subcontext_df(level + 1, report_df, local_context_df)
    community_df = _get_community_df(
        level, invalid_context_df, sub_context_df, community_hierarchy_df, max_tokens
    )

    # handle any remaining invalid records that can't be substituted with sub-community reports
    # this should be rare, but if it happens, we will just trim the local context to fit the limit
    remaining_df = _antijoin_reports(invalid_context_df, community_df)
    remaining_df.loc[:, schemas.CONTEXT_STRING] = _sort_and_trim_context(
        remaining_df, max_tokens
    )

    result = union(valid_context_df, community_df, remaining_df)
    result[schemas.CONTEXT_SIZE] = result.loc[:, schemas.CONTEXT_STRING].map(num_tokens)

    result[schemas.CONTEXT_EXCEED_FLAG] = 0
    return result


def _drop_community_level(df: pd.DataFrame) -> pd.DataFrame:
    """Drop the community level column from the dataframe."""
    return drop_columns(df, schemas.COMMUNITY_LEVEL)


def _at_level(level: int, df: pd.DataFrame) -> pd.DataFrame:
    """Return records at the given level."""
    return where_column_equals(df, schemas.COMMUNITY_LEVEL, level)


def _antijoin_reports(df: pd.DataFrame, reports: pd.DataFrame) -> pd.DataFrame:
    """Return records in df that are not in reports."""
    return antijoin(df, reports, schemas.NODE_COMMUNITY)


def _sort_and_trim_context(df: pd.DataFrame, max_tokens: int) -> pd.Series:
    """Sort and trim context to fit the limit."""
    series = cast("pd.Series", df[schemas.ALL_CONTEXT])
    return transform_series(series, lambda x: sort_context(x, max_tokens=max_tokens))


def _build_mixed_context(df: pd.DataFrame, max_tokens: int) -> pd.Series:
    """Build mixed context (local context plus sub-community reports) to fit the limit."""
    series = cast("pd.Series", df[schemas.ALL_CONTEXT])
    return transform_series(
        series, lambda x: build_mixed_context(x, max_tokens=max_tokens)
    )


def _get_subcontext_df(
    level: int, report_df: pd.DataFrame, local_context_df: pd.DataFrame
) -> pd.DataFrame:
    """Get sub-community context for each community."""
    sub_report_df = _drop_community_level(_at_level(level, report_df))
    sub_context_df = _at_level(level, local_context_df)
    sub_context_df = join(sub_context_df, sub_report_df, schemas.NODE_COMMUNITY)
    sub_context_df.rename(
        columns={schemas.NODE_COMMUNITY: schemas.SUB_COMMUNITY}, inplace=True
    )
    return sub_context_df


def _get_community_df(
    level: int,
    invalid_context_df: pd.DataFrame,
    sub_context_df: pd.DataFrame,
    community_hierarchy_df: pd.DataFrame,
    max_tokens: int,
) -> pd.DataFrame:
    """Get community context for each community."""
    # collect all sub communities' contexts for each community
    community_df = _drop_community_level(_at_level(level, community_hierarchy_df))
    invalid_community_ids = select(invalid_context_df, schemas.NODE_COMMUNITY)
    subcontext_selection = select(
        sub_context_df,
        schemas.SUB_COMMUNITY,
        schemas.FULL_CONTENT,
        schemas.ALL_CONTEXT,
        schemas.CONTEXT_SIZE,
    )

    invalid_communities = join(
        community_df, invalid_community_ids, schemas.NODE_COMMUNITY, "inner"
    )
    community_df = join(
        invalid_communities, subcontext_selection, schemas.SUB_COMMUNITY
    )
    community_df[schemas.ALL_CONTEXT] = community_df.apply(
        lambda x: {
            schemas.SUB_COMMUNITY: x[schemas.SUB_COMMUNITY],
            schemas.ALL_CONTEXT: x[schemas.ALL_CONTEXT],
            schemas.FULL_CONTENT: x[schemas.FULL_CONTENT],
            schemas.CONTEXT_SIZE: x[schemas.CONTEXT_SIZE],
        },
        axis=1,
    )
    community_df = (
        community_df.groupby(schemas.NODE_COMMUNITY)
        .agg({schemas.ALL_CONTEXT: list})
        .reset_index()
    )
    community_df[schemas.CONTEXT_STRING] = _build_mixed_context(
        community_df, max_tokens
    )
    community_df[schemas.COMMUNITY_LEVEL] = level
    return community_df
@ -0,0 +1,52 @@
# Copyright (c) 2024 Microsoft Corporation.
# Licensed under the MIT License

"""Common field name definitions for community reports."""

# POST-PREP NODE TABLE SCHEMA
NODE_ID = "human_readable_id"
NODE_NAME = "title"
NODE_DESCRIPTION = "description"
NODE_DEGREE = "degree"
NODE_DETAILS = "node_details"
NODE_COMMUNITY = "community"
NODE_LEVEL = "level"

# POST-PREP EDGE TABLE SCHEMA
EDGE_ID = "human_readable_id"
EDGE_SOURCE = "source"
EDGE_TARGET = "target"
EDGE_DESCRIPTION = "description"
EDGE_DEGREE = "combined_degree"
EDGE_DETAILS = "edge_details"
EDGE_WEIGHT = "weight"

# POST-PREP CLAIM TABLE SCHEMA
CLAIM_ID = "human_readable_id"
CLAIM_SUBJECT = "subject_id"
CLAIM_TYPE = "type"
CLAIM_STATUS = "status"
CLAIM_DESCRIPTION = "description"
CLAIM_DETAILS = "claim_details"

# COMMUNITY HIERARCHY TABLE SCHEMA
SUB_COMMUNITY = "sub_community"
SUB_COMMUNITY_SIZE = "sub_community_size"
COMMUNITY_LEVEL = "level"

# COMMUNITY CONTEXT TABLE SCHEMA
ALL_CONTEXT = "all_context"
CONTEXT_STRING = "context_string"
CONTEXT_SIZE = "context_size"
CONTEXT_EXCEED_FLAG = "context_exceed_limit"

# COMMUNITY REPORT TABLE SCHEMA
REPORT_ID = "id"
COMMUNITY_ID = "community"
COMMUNITY_LEVEL = "level"
TITLE = "title"
SUMMARY = "summary"
FINDINGS = "findings"
RATING = "rank"
EXPLANATION = "rating_explanation"
FULL_CONTENT = "full_content"
FULL_CONTENT_JSON = "full_content_json"
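
# Usage note (illustrative): these constants keep column references consistent across
# the report pipeline, e.g. selecting report columns as
# report_df[[REPORT_ID, COMMUNITY_ID, TITLE, SUMMARY, RATING, EXPLANATION]].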
@ -0,0 +1,154 @@
# Copyright (c) 2024 Microsoft Corporation.
# Licensed under the MIT License

"""Sort context by degree in descending order."""

import pandas as pd

import graphrag.index.operations.summarize_communities.community_reports_extractor.schemas as schemas
from graphrag.query.llm.text_utils import num_tokens


def sort_context(
    local_context: list[dict],
    sub_community_reports: list[dict] | None = None,
    max_tokens: int | None = None,
    node_name_column: str = schemas.NODE_NAME,
    node_details_column: str = schemas.NODE_DETAILS,
    edge_id_column: str = schemas.EDGE_ID,
    edge_details_column: str = schemas.EDGE_DETAILS,
    edge_degree_column: str = schemas.EDGE_DEGREE,
    edge_source_column: str = schemas.EDGE_SOURCE,
    edge_target_column: str = schemas.EDGE_TARGET,
    claim_details_column: str = schemas.CLAIM_DETAILS,
) -> str:
    """Sort context by degree in descending order, optimizing for performance."""

    def _get_context_string(
        entities: list[dict],
        edges: list[dict],
        claims: list[dict],
        sub_community_reports: list[dict] | None = None,
    ) -> str:
        """Concatenate structured data into a context string."""
        contexts = []
        if sub_community_reports:
            report_df = pd.DataFrame(sub_community_reports)
            if not report_df.empty:
                contexts.append(
                    f"----Reports-----\n{report_df.to_csv(index=False, sep=',')}"
                )

        for label, data in [
            ("Entities", entities),
            ("Claims", claims),
            ("Relationships", edges),
        ]:
            if data:
                data_df = pd.DataFrame(data)
                if not data_df.empty:
                    contexts.append(
                        f"-----{label}-----\n{data_df.to_csv(index=False, sep=',')}"
                    )

        return "\n\n".join(contexts)

    # Preprocess local context
    edges = [
        {**e, schemas.EDGE_ID: int(e[schemas.EDGE_ID])}
        for record in local_context
        for e in record.get(edge_details_column, [])
        if isinstance(e, dict)
    ]

    node_details = {
        record[node_name_column]: {
            **record[node_details_column],
            schemas.NODE_ID: int(record[node_details_column][schemas.NODE_ID]),
        }
        for record in local_context
    }

    claim_details = {
        record[node_name_column]: [
            {**c, schemas.CLAIM_ID: int(c[schemas.CLAIM_ID])}
            for c in record.get(claim_details_column, [])
            if isinstance(c, dict) and c.get(schemas.CLAIM_ID) is not None
        ]
        for record in local_context
        if isinstance(record.get(claim_details_column), list)
    }

    # Sort edges by degree (desc) and ID (asc)
    edges.sort(key=lambda x: (-x.get(edge_degree_column, 0), x.get(edge_id_column, "")))

    # Deduplicate and build the context string incrementally
    edge_ids, nodes_ids, claims_ids = set(), set(), set()
    sorted_edges, sorted_nodes, sorted_claims = [], [], []
    context_string = ""

    for edge in edges:
        source, target = edge[edge_source_column], edge[edge_target_column]

        # Add source and target node details
        for node in [node_details.get(source), node_details.get(target)]:
            if node and node[schemas.NODE_ID] not in nodes_ids:
                nodes_ids.add(node[schemas.NODE_ID])
                sorted_nodes.append(node)

        # Add claims related to source and target
        for claims in [claim_details.get(source), claim_details.get(target)]:
            if claims:
                for claim in claims:
                    if claim[schemas.CLAIM_ID] not in claims_ids:
                        claims_ids.add(claim[schemas.CLAIM_ID])
                        sorted_claims.append(claim)

        # Add the edge
        if edge[schemas.EDGE_ID] not in edge_ids:
            edge_ids.add(edge[schemas.EDGE_ID])
            sorted_edges.append(edge)

        # Generate the new context string; stop before exceeding the token budget
        new_context_string = _get_context_string(
            sorted_nodes, sorted_edges, sorted_claims, sub_community_reports
        )
        if max_tokens and num_tokens(new_context_string) > max_tokens:
            break
        context_string = new_context_string

    # Return the final context string
    return context_string or _get_context_string(
        sorted_nodes, sorted_edges, sorted_claims, sub_community_reports
    )
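
# Usage sketch (hypothetical records, illustration only): a two-node community with a
# single shared edge; the keys follow the schema constants above.
# demo_context = [
#     {
#         "title": "A",
#         "node_details": {"human_readable_id": 0, "title": "A", "degree": 1},
#         "edge_details": [
#             {"human_readable_id": 0, "source": "A", "target": "B", "combined_degree": 2}
#         ],
#         "claim_details": [],
#     },
#     {
#         "title": "B",
#         "node_details": {"human_readable_id": 1, "title": "B", "degree": 1},
#         "edge_details": [
#             {"human_readable_id": 0, "source": "A", "target": "B", "combined_degree": 2}
#         ],
#         "claim_details": [],
#     },
# ]
# sort_context(demo_context) returns "-----Entities-----\n..." CSV blocks, with the
# shared edge deduplicated by its id.
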
def parallel_sort_context_batch(community_df, max_tokens, parallel=False):
    """Calculate context using parallelization if enabled."""
    if parallel:
        # Use ThreadPoolExecutor for parallel execution
        from concurrent.futures import ThreadPoolExecutor

        with ThreadPoolExecutor(max_workers=None) as executor:
            context_strings = list(
                executor.map(
                    lambda x: sort_context(x, max_tokens=max_tokens),
                    community_df[schemas.ALL_CONTEXT],
                )
            )
            community_df[schemas.CONTEXT_STRING] = context_strings
    else:
        # Assign context strings directly to the DataFrame
        community_df[schemas.CONTEXT_STRING] = community_df[schemas.ALL_CONTEXT].apply(
            lambda context_list: sort_context(context_list, max_tokens=max_tokens)
        )

    # Calculate the derived context columns
    community_df[schemas.CONTEXT_SIZE] = community_df[schemas.CONTEXT_STRING].apply(
        num_tokens
    )
    community_df[schemas.CONTEXT_EXCEED_FLAG] = (
        community_df[schemas.CONTEXT_SIZE] > max_tokens
    )

    return community_df
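
# Usage sketch (assumed DataFrame shape): one row per community with an "all_context"
# list column, as produced by the local-context preparation step.
# community_df = parallel_sort_context_batch(community_df, max_tokens=16_000, parallel=True)
# Rows whose trimmed context still exceeds max_tokens are flagged via
# "context_exceed_limit" so callers can fall back to sub-community report substitution.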
@ -0,0 +1,15 @@
# Copyright (c) 2024 Microsoft Corporation.
# Licensed under the MIT License

"""A module containing community report generation utilities."""

import pandas as pd

import graphrag.index.operations.summarize_communities.community_reports_extractor.schemas as schemas


def get_levels(df: pd.DataFrame, level_column: str = schemas.NODE_LEVEL) -> list[int]:
    """Get the distinct community levels, sorted in descending order."""
    levels = df[level_column].dropna().unique()
    levels = [int(lvl) for lvl in levels if lvl != -1]
    return sorted(levels, reverse=True)
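
# Example (illustrative): levels come back deepest-first so reports can be built
# bottom-up, e.g. get_levels(pd.DataFrame({"level": [0, 1, 1, -1]})) == [1, 0]
# (the -1 sentinel is dropped).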
@ -0,0 +1,140 @@
# Copyright (c) 2024 Microsoft Corporation.
# Licensed under the MIT License

"""A module containing 'SummarizationResult' and 'SummarizeExtractor' models."""

import json
from dataclasses import dataclass

from fnllm import ChatLLM

from graphrag.index.typing import ErrorHandlerFn
from graphrag.index.utils.tokens import num_tokens_from_string
from graphrag.prompts.index.summarize_descriptions import SUMMARIZE_PROMPT

# Max token size for input prompts
DEFAULT_MAX_INPUT_TOKENS = 4_000
# Max token count for LLM answers
DEFAULT_MAX_SUMMARY_LENGTH = 500


@dataclass
class SummarizationResult:
    """Description summarization result class definition."""

    id: str | tuple[str, str]
    description: str


class SummarizeExtractor:
    """Description summarization extractor class definition."""

    _llm: ChatLLM
    _entity_name_key: str
    _input_descriptions_key: str
    _summarization_prompt: str
    _on_error: ErrorHandlerFn
    _max_summary_length: int
    _max_input_tokens: int

    def __init__(
        self,
        llm_invoker: ChatLLM,
        entity_name_key: str | None = None,
        input_descriptions_key: str | None = None,
        summarization_prompt: str | None = None,
        on_error: ErrorHandlerFn | None = None,
        max_summary_length: int | None = None,
        max_input_tokens: int | None = None,
    ):
        """Init method definition."""
        # TODO: streamline construction
        self._llm = llm_invoker
        self._entity_name_key = entity_name_key or "entity_name"
        self._input_descriptions_key = input_descriptions_key or "description_list"

        self._summarization_prompt = summarization_prompt or SUMMARIZE_PROMPT
        self._on_error = on_error or (lambda _e, _s, _d: None)
        self._max_summary_length = max_summary_length or DEFAULT_MAX_SUMMARY_LENGTH
        self._max_input_tokens = max_input_tokens or DEFAULT_MAX_INPUT_TOKENS

    async def __call__(
        self,
        id: str | tuple[str, str],
        descriptions: list[str],
    ) -> SummarizationResult:
        """Call method definition."""
        result = ""
        if len(descriptions) == 0:
            result = ""
        elif len(descriptions) == 1:
            result = descriptions[0]
        else:
            result = await self._summarize_descriptions(id, descriptions)

        return SummarizationResult(
            id=id,
            description=result or "",
        )

    async def _summarize_descriptions(
        self, id: str | tuple[str, str], descriptions: list[str]
    ) -> str:
        """Summarize descriptions into a single description."""
        # Sort tuple ids (edge keys) so (source, target) and (target, source) match
        sorted_id = sorted(id) if isinstance(id, (list, tuple)) else id

        # Safety check, should always be a list
        if not isinstance(descriptions, list):
            descriptions = [descriptions]

        # Sort description lists for deterministic prompts
        if len(descriptions) > 1:
            descriptions = sorted(descriptions)

        # Iterate over descriptions, adding all until the max input tokens is reached
        usable_tokens = self._max_input_tokens - num_tokens_from_string(
            self._summarization_prompt
        )
        descriptions_collected = []
        result = ""

        for i, description in enumerate(descriptions):
            usable_tokens -= num_tokens_from_string(description)
            descriptions_collected.append(description)

            # If the buffer is full, or all descriptions have been added, summarize
            if (usable_tokens < 0 and len(descriptions_collected) > 1) or (
                i == len(descriptions) - 1
            ):
                # Calculate the result (final or partial)
                result = await self._summarize_descriptions_with_llm(
                    sorted_id, descriptions_collected
                )

                # If we go for another loop, seed the buffer with the partial summary
                if i != len(descriptions) - 1:
                    descriptions_collected = [result]
                    usable_tokens = (
                        self._max_input_tokens
                        - num_tokens_from_string(self._summarization_prompt)
                        - num_tokens_from_string(result)
                    )

        return result

    async def _summarize_descriptions_with_llm(
        self, id: str | tuple[str, str] | list[str], descriptions: list[str]
    ):
        """Summarize descriptions using the LLM."""
        response = await self._llm(
            self._summarization_prompt.format(**{
                self._entity_name_key: json.dumps(id, ensure_ascii=False),
                self._input_descriptions_key: json.dumps(
                    sorted(descriptions), ensure_ascii=False
                ),
            }),
            name="summarize",
            model_parameters={"max_tokens": self._max_summary_length},
        )
        # Return the summary text
        return str(response.output.content)
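
# Worked example (illustrative numbers): with DEFAULT_MAX_INPUT_TOKENS = 4_000 and a
# summarization prompt of ~300 tokens, usable_tokens starts at 3_700. Given ten
# 500-token descriptions, the buffer folds after the eighth (3_700 - 8 * 500 < 0):
# a partial summary is produced, seeded back as the sole collected description,
# and iteration continues until a single summary covers all ten.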
@ -0,0 +1,64 @@
# Copyright (c) 2024 Microsoft Corporation.
# Licensed under the MIT License

"""A module containing the run_graph_intelligence and run_summarize_descriptions methods."""

from datashaper import VerbCallbacks
from fnllm import ChatLLM

from graphrag.cache.pipeline_cache import PipelineCache
from graphrag.index.llm.load_llm import load_llm, read_llm_params
from graphrag.index.operations.summarize_descriptions.description_summary_extractor import (
    SummarizeExtractor,
)
from graphrag.index.operations.summarize_descriptions.typing import (
    StrategyConfig,
    SummarizedDescriptionResult,
)


async def run_graph_intelligence(
    id: str | tuple[str, str],
    descriptions: list[str],
    callbacks: VerbCallbacks,
    cache: PipelineCache,
    args: StrategyConfig,
) -> SummarizedDescriptionResult:
    """Run the graph intelligence description summarization strategy."""
    llm_config = read_llm_params(args.get("llm", {}))
    llm = load_llm(
        "summarize_descriptions", llm_config, callbacks=callbacks, cache=cache
    )
    return await run_summarize_descriptions(llm, id, descriptions, callbacks, args)


async def run_summarize_descriptions(
    llm: ChatLLM,
    id: str | tuple[str, str],
    descriptions: list[str],
    callbacks: VerbCallbacks,
    args: StrategyConfig,
) -> SummarizedDescriptionResult:
    """Run the description summarization chain."""
    # Summarization arguments
    summarize_prompt = args.get("summarize_prompt", None)
    entity_name_key = args.get("entity_name_key", "entity_name")
    input_descriptions_key = args.get("input_descriptions_key", "description_list")
    max_tokens = args.get("max_tokens", None)

    extractor = SummarizeExtractor(
        llm_invoker=llm,
        summarization_prompt=summarize_prompt,
        entity_name_key=entity_name_key,
        input_descriptions_key=input_descriptions_key,
        on_error=lambda e, stack, details: (
            callbacks.error("Description Summarization Error", e, stack, details)
            if callbacks
            else None
        ),
        max_summary_length=args.get("max_summary_length", None),
        max_input_tokens=max_tokens,
    )

    result = await extractor(id=id, descriptions=descriptions)
    return SummarizedDescriptionResult(id=result.id, description=result.description)
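
# Usage sketch (hypothetical config values): the strategy config is a plain dict, e.g.
# args = {"llm": {...}, "max_summary_length": 500, "max_tokens": 4_000}
# result = await run_summarize_descriptions(
#     llm, ("A", "B"), ["desc one", "desc two"], callbacks, args
# )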
67
graphrag/index/utils/stable_lcc.py
Normal file
67
graphrag/index/utils/stable_lcc.py
Normal file
@ -0,0 +1,67 @@
# Copyright (c) 2024 Microsoft Corporation.
# Licensed under the MIT License

"""A module for producing a stable largest connected component, i.e. same input graph == same output lcc."""

import html
from typing import Any, cast

import networkx as nx


def stable_largest_connected_component(graph: nx.Graph) -> nx.Graph:
    """Return the largest connected component of the graph, with nodes and edges sorted in a stable way."""
    # NOTE: The import is done here to reduce the initial import time of the module
    from graspologic.utils import largest_connected_component

    graph = graph.copy()
    graph = cast("nx.Graph", largest_connected_component(graph))
    graph = normalize_node_names(graph)
    return _stabilize_graph(graph)


def _stabilize_graph(graph: nx.Graph) -> nx.Graph:
    """Ensure an undirected graph with the same relationships will always be read the same way."""
    fixed_graph = nx.DiGraph() if graph.is_directed() else nx.Graph()

    sorted_nodes = graph.nodes(data=True)
    sorted_nodes = sorted(sorted_nodes, key=lambda x: x[0])

    fixed_graph.add_nodes_from(sorted_nodes)
    edges = list(graph.edges(data=True))

    # If the graph is undirected, create the edges in a stable way so we always get
    # the same result. In graph theory, A -> B is the same as B -> A in an undirected
    # graph, but downstream consumers sometimes read graph.nodes() as [A, B] and
    # sometimes as [B, A], and base some of their logic on that order. Sorting the
    # nodes within each edge keeps the order deterministic.
    if not graph.is_directed():

        def _sort_source_target(edge):
            source, target, edge_data = edge
            if source > target:
                source, target = target, source
            return source, target, edge_data

        edges = [_sort_source_target(edge) for edge in edges]

    def _get_edge_key(source: Any, target: Any) -> str:
        return f"{source} -> {target}"

    edges = sorted(edges, key=lambda x: _get_edge_key(x[0], x[1]))

    fixed_graph.add_edges_from(edges)
    return fixed_graph


def normalize_node_names(graph: nx.Graph | nx.DiGraph) -> nx.Graph | nx.DiGraph:
    """Normalize node names by upper-casing, trimming, and HTML-unescaping them."""
    node_mapping = {node: html.unescape(node.upper().strip()) for node in graph.nodes()}  # type: ignore
    return nx.relabel_nodes(graph, node_mapping)
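
# Example (illustrative): insertion order no longer matters once the graph is stabilized.
# g = nx.Graph()
# g.add_edge("b ", "a")
# lcc = stable_largest_connected_component(g)
# list(lcc.edges()) -> [("A", "B")] whether the edge was added as (b, a) or (a, b).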
2
graphrag/py.typed
Normal file
2
graphrag/py.typed
Normal file
@ -0,0 +1,2 @@
# This package supports type hinting,
# see https://www.python.org/dev/peps/pep-0561/#packaging-type-information
109
tests/integration/storage/test_cosmosdb_storage.py
Normal file
109
tests/integration/storage/test_cosmosdb_storage.py
Normal file
@ -0,0 +1,109 @@
# Copyright (c) 2024 Microsoft Corporation.
# Licensed under the MIT License
"""CosmosDB Storage Tests."""

import json
import re
import sys

import pytest

from graphrag.storage.cosmosdb_pipeline_storage import CosmosDBPipelineStorage

# cspell:disable-next-line well-known-key
WELL_KNOWN_COSMOS_CONNECTION_STRING = "AccountEndpoint=https://127.0.0.1:8081/;AccountKey=C2y6yDjf5/R+ob0N8A7Cgv30VRDJIWEHLM+4QDU5DE2nQ9nDuVTqobD4b8mGGyPMbIZnqyMsEcaGQy67XIw/Jw=="

# the cosmosdb emulator is only available on windows runners at this time
if not sys.platform.startswith("win"):
    pytest.skip(
        "encountered windows-only tests -- will skip for now", allow_module_level=True
    )


async def test_find():
    storage = CosmosDBPipelineStorage(
        connection_string=WELL_KNOWN_COSMOS_CONNECTION_STRING,
        database_name="testfind",
        container_name="testfindcontainer",
    )
    try:
        try:
            items = list(storage.find(file_pattern=re.compile(r".*\.json$")))
            items = [item[0] for item in items]
            assert items == []

            json_content = {
                "content": "Merry Christmas!",
            }
            await storage.set(
                "christmas.json", json.dumps(json_content), encoding="utf-8"
            )
            items = list(storage.find(file_pattern=re.compile(r".*\.json$")))
            items = [item[0] for item in items]
            assert items == ["christmas.json"]

            json_content = {
                "content": "Hello, World!",
            }
            await storage.set("test.json", json.dumps(json_content), encoding="utf-8")
            items = list(storage.find(file_pattern=re.compile(r".*\.json$")))
            items = [item[0] for item in items]
            assert items == ["christmas.json", "test.json"]

            output = await storage.get("test.json")
            output_json = json.loads(output)
            assert output_json["content"] == "Hello, World!"

            json_exists = await storage.has("christmas.json")
            assert json_exists is True
            json_exists = await storage.has("easter.json")
            assert json_exists is False
        finally:
            await storage.delete("test.json")
            output = await storage.get("test.json")
            assert output is None
    finally:
        await storage.clear()


async def test_child():
    storage = CosmosDBPipelineStorage(
        connection_string=WELL_KNOWN_COSMOS_CONNECTION_STRING,
        database_name="testchild",
        container_name="testchildcontainer",
    )
    try:
        child_storage = storage.child("child")
        assert type(child_storage) is CosmosDBPipelineStorage
    finally:
        await storage.clear()


async def test_clear():
    storage = CosmosDBPipelineStorage(
        connection_string=WELL_KNOWN_COSMOS_CONNECTION_STRING,
        database_name="testclear",
        container_name="testclearcontainer",
    )
    try:
        json_exists = {
            "content": "Merry Christmas!",
        }
        await storage.set("christmas.json", json.dumps(json_exists), encoding="utf-8")
        json_exists = {
            "content": "Happy Easter!",
        }
        await storage.set("easter.json", json.dumps(json_exists), encoding="utf-8")
        await storage.clear()

        items = list(storage.find(file_pattern=re.compile(r".*\.json$")))
        items = [item[0] for item in items]
        assert items == []

        output = await storage.get("easter.json")
        assert output is None

        assert storage._container_client is None  # noqa: SLF001
        assert storage._database_client is None  # noqa: SLF001
    finally:
        await storage.clear()