partially resolved merge conflicts

Kenny Zhang 2024-12-19 17:13:14 -05:00
parent 9ca67643b4
commit 14b1eccbff
27 changed files with 2197 additions and 0 deletions

.semversioner/1.0.1.json Normal file
@@ -0,0 +1,22 @@
{
"changes": [
{
"description": "Fix encoding model config parsing",
"type": "patch"
},
{
"description": "Fix exception on error callbacks",
"type": "patch"
},
{
"description": "Manage llm instances inside a cached singleton. Check for empty dfs after entity/relationship extraction",
"type": "patch"
},
{
"description": "Respect encoding_model option",
"type": "patch"
}
],
"created_at": "2024-12-18T23:12:52+00:00",
"version": "1.0.1"
}

@@ -0,0 +1,4 @@
{
"type": "patch",
"description": "Implement cosmosdb storage option for cache and output"
}

@@ -0,0 +1,4 @@
{
"type": "patch",
"description": "Streamline flows."
}

@@ -0,0 +1,4 @@
{
"type": "patch",
"description": "Move extractor code to co-locate with operations."
}

@@ -0,0 +1,40 @@
# Copyright (c) 2024 Microsoft Corporation.
# Licensed under the MIT License
"""LLM Manager singleton."""
from functools import cache
from fnllm import ChatLLM, EmbeddingsLLM
@cache
class ChatLLMSingleton:
"""A singleton class for the chat LLM instances."""
def __init__(self):
self.llm_dict = {}
def set_llm(self, name, llm):
"""Add an LLM to the dictionary."""
self.llm_dict[name] = llm
def get_llm(self, name) -> ChatLLM | None:
"""Get an LLM from the dictionary."""
return self.llm_dict.get(name)
@cache
class EmbeddingsLLMSingleton:
"""A singleton class for the embeddings LLM instances."""
def __init__(self):
self.llm_dict = {}
def set_llm(self, name, llm):
"""Add an LLM to the dictionary."""
self.llm_dict[name] = llm
def get_llm(self, name) -> EmbeddingsLLM | None:
"""Get an LLM from the dictionary."""
return self.llm_dict.get(name)
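
The @cache decorator memoizes calls to the class itself, so every no-argument instantiation returns the same object; that is what makes these managers behave as process-wide singletons. A minimal, self-contained sketch of the pattern (generic names, not the graphrag API):

from functools import cache

@cache
class _Registry:
    """Illustrative stand-in: functools.cache memoizes the constructor call."""

    def __init__(self):
        self.llm_dict = {}

    def set_llm(self, name, llm):
        self.llm_dict[name] = llm

    def get_llm(self, name):
        return self.llm_dict.get(name)

a = _Registry()
b = _Registry()
assert a is b  # both calls return the cached instance
a.set_llm("default_chat", "chat-llm-placeholder")
assert b.get_llm("default_chat") == "chat-llm-placeholder"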

@@ -0,0 +1,15 @@
# Copyright (c) 2024 Microsoft Corporation.
# Licensed under the MIT License
"""A module containing create_graph definition."""
import networkx as nx
import pandas as pd
def compute_degree(graph: nx.Graph) -> pd.DataFrame:
"""Create a new DataFrame with the degree of each node in the graph."""
return pd.DataFrame([
{"title": node, "degree": int(degree)}
for node, degree in graph.degree # type: ignore
])
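
A quick usage sketch (not part of the commit); the frame construction repeats the expression from compute_degree so the snippet runs standalone:

import networkx as nx
import pandas as pd

g = nx.Graph()
g.add_edges_from([("A", "B"), ("A", "C")])

# Same expression as compute_degree above.
df = pd.DataFrame([{"title": node, "degree": int(degree)} for node, degree in g.degree])
print(df)
#   title  degree
# 0     A       2
# 1     B       1
# 2     C       1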

@@ -0,0 +1,43 @@
# Copyright (c) 2024 Microsoft Corporation.
# Licensed under the MIT License
"""Utilities to generate graph embeddings."""
from dataclasses import dataclass
import networkx as nx
import numpy as np
@dataclass
class NodeEmbeddings:
"""Node embeddings class definition."""
nodes: list[str]
embeddings: np.ndarray
def embed_node2vec(
graph: nx.Graph | nx.DiGraph,
dimensions: int = 1536,
num_walks: int = 10,
walk_length: int = 40,
window_size: int = 2,
iterations: int = 3,
random_seed: int = 86,
) -> NodeEmbeddings:
"""Generate node embeddings using Node2Vec."""
# NOTE: This import is done here to reduce the initial import time of the graphrag package
import graspologic as gc
# generate embedding
lcc_tensors = gc.embed.node2vec_embed( # type: ignore
graph=graph,
dimensions=dimensions,
window_size=window_size,
iterations=iterations,
num_walks=num_walks,
walk_length=walk_length,
random_seed=random_seed,
)
return NodeEmbeddings(embeddings=lcc_tensors[0], nodes=lcc_tensors[1])
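
A usage sketch under two assumptions: graspologic is installed, and the function lives at the import path shown below, which is a guess since this diff view does not display file paths.

import networkx as nx

# Hypothetical import path; adjust to wherever this module actually lands.
from graphrag.index.operations.embed_graph.embed_node2vec import embed_node2vec

g = nx.karate_club_graph()
g = nx.relabel_nodes(g, {n: f"node-{n}" for n in g.nodes})  # string labels, matching NodeEmbeddings.nodes

result = embed_node2vec(g, dimensions=32, num_walks=5, walk_length=20)
print(len(result.nodes))        # 34 nodes in the karate club graph
print(result.embeddings.shape)  # (34, 32)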

@@ -0,0 +1,242 @@
# Copyright (c) 2024 Microsoft Corporation.
# Licensed under the MIT License
"""A module containing 'ClaimExtractorResult' and 'ClaimExtractor' models."""
import logging
import traceback
from dataclasses import dataclass
from typing import Any
import tiktoken
from fnllm import ChatLLM
import graphrag.config.defaults as defs
from graphrag.index.typing import ErrorHandlerFn
from graphrag.prompts.index.claim_extraction import (
CLAIM_EXTRACTION_PROMPT,
CONTINUE_PROMPT,
LOOP_PROMPT,
)
DEFAULT_TUPLE_DELIMITER = "<|>"
DEFAULT_RECORD_DELIMITER = "##"
DEFAULT_COMPLETION_DELIMITER = "<|COMPLETE|>"
log = logging.getLogger(__name__)
@dataclass
class ClaimExtractorResult:
"""Claim extractor result class definition."""
output: list[dict]
source_docs: dict[str, Any]
class ClaimExtractor:
"""Claim extractor class definition."""
_llm: ChatLLM
_extraction_prompt: str
_summary_prompt: str
_output_formatter_prompt: str
_input_text_key: str
_input_entity_spec_key: str
_input_claim_description_key: str
_tuple_delimiter_key: str
_record_delimiter_key: str
_completion_delimiter_key: str
_max_gleanings: int
_on_error: ErrorHandlerFn
_loop_args: dict[str, Any]
def __init__(
self,
llm_invoker: ChatLLM,
extraction_prompt: str | None = None,
input_text_key: str | None = None,
input_entity_spec_key: str | None = None,
input_claim_description_key: str | None = None,
input_resolved_entities_key: str | None = None,
tuple_delimiter_key: str | None = None,
record_delimiter_key: str | None = None,
completion_delimiter_key: str | None = None,
encoding_model: str | None = None,
max_gleanings: int | None = None,
on_error: ErrorHandlerFn | None = None,
):
"""Init method definition."""
self._llm = llm_invoker
self._extraction_prompt = extraction_prompt or CLAIM_EXTRACTION_PROMPT
self._input_text_key = input_text_key or "input_text"
self._input_entity_spec_key = input_entity_spec_key or "entity_specs"
self._tuple_delimiter_key = tuple_delimiter_key or "tuple_delimiter"
self._record_delimiter_key = record_delimiter_key or "record_delimiter"
self._completion_delimiter_key = (
completion_delimiter_key or "completion_delimiter"
)
self._input_claim_description_key = (
input_claim_description_key or "claim_description"
)
self._input_resolved_entities_key = (
input_resolved_entities_key or "resolved_entities"
)
self._max_gleanings = (
max_gleanings if max_gleanings is not None else defs.CLAIM_MAX_GLEANINGS
)
self._on_error = on_error or (lambda _e, _s, _d: None)
# Construct the looping arguments
encoding = tiktoken.get_encoding(encoding_model or "cl100k_base")
yes = f"{encoding.encode('YES')[0]}"
no = f"{encoding.encode('NO')[0]}"
self._loop_args = {"logit_bias": {yes: 100, no: 100}, "max_tokens": 1}
async def __call__(
self, inputs: dict[str, Any], prompt_variables: dict | None = None
) -> ClaimExtractorResult:
"""Call method definition."""
if prompt_variables is None:
prompt_variables = {}
texts = inputs[self._input_text_key]
entity_spec = str(inputs[self._input_entity_spec_key])
claim_description = inputs[self._input_claim_description_key]
resolved_entities = inputs.get(self._input_resolved_entities_key, {})
source_doc_map = {}
prompt_args = {
self._input_entity_spec_key: entity_spec,
self._input_claim_description_key: claim_description,
self._tuple_delimiter_key: prompt_variables.get(self._tuple_delimiter_key)
or DEFAULT_TUPLE_DELIMITER,
self._record_delimiter_key: prompt_variables.get(self._record_delimiter_key)
or DEFAULT_RECORD_DELIMITER,
self._completion_delimiter_key: prompt_variables.get(
self._completion_delimiter_key
)
or DEFAULT_COMPLETION_DELIMITER,
}
all_claims: list[dict] = []
for doc_index, text in enumerate(texts):
document_id = f"d{doc_index}"
try:
claims = await self._process_document(prompt_args, text, doc_index)
all_claims += [
self._clean_claim(c, document_id, resolved_entities) for c in claims
]
source_doc_map[document_id] = text
except Exception as e:
log.exception("error extracting claim")
self._on_error(
e,
traceback.format_exc(),
{"doc_index": doc_index, "text": text},
)
continue
return ClaimExtractorResult(
output=all_claims,
source_docs=source_doc_map,
)
def _clean_claim(
self, claim: dict, document_id: str, resolved_entities: dict
) -> dict:
# clean the parsed claims to remove any claims with status = False
obj = claim.get("object_id", claim.get("object"))
subject = claim.get("subject_id", claim.get("subject"))
# If subject or object in resolved entities, then replace with resolved entity
obj = resolved_entities.get(obj, obj)
subject = resolved_entities.get(subject, subject)
claim["object_id"] = obj
claim["subject_id"] = subject
return claim
async def _process_document(
self, prompt_args: dict, doc, doc_index: int
) -> list[dict]:
record_delimiter = prompt_args.get(
self._record_delimiter_key, DEFAULT_RECORD_DELIMITER
)
completion_delimiter = prompt_args.get(
self._completion_delimiter_key, DEFAULT_COMPLETION_DELIMITER
)
response = await self._llm(
self._extraction_prompt.format(**{
self._input_text_key: doc,
**prompt_args,
})
)
results = response.output.content or ""
claims = results.strip().removesuffix(completion_delimiter)
# Repeat to ensure we maximize entity count
for i in range(self._max_gleanings):
response = await self._llm(
CONTINUE_PROMPT,
name=f"extract-continuation-{i}",
history=response.history,
)
extension = response.output.content or ""
claims += record_delimiter + extension.strip().removesuffix(
completion_delimiter
)
# If this isn't the last loop, check to see if we should continue
if i >= self._max_gleanings - 1:
break
response = await self._llm(
LOOP_PROMPT,
name=f"extract-loopcheck-{i}",
history=response.history,
model_parameters=self._loop_args,
)
if response.output.content != "YES":
break
return self._parse_claim_tuples(claims, prompt_args)
def _parse_claim_tuples(
self, claims: str, prompt_variables: dict
) -> list[dict[str, Any]]:
"""Parse claim tuples."""
record_delimiter = prompt_variables.get(
self._record_delimiter_key, DEFAULT_RECORD_DELIMITER
)
completion_delimiter = prompt_variables.get(
self._completion_delimiter_key, DEFAULT_COMPLETION_DELIMITER
)
tuple_delimiter = prompt_variables.get(
self._tuple_delimiter_key, DEFAULT_TUPLE_DELIMITER
)
def pull_field(index: int, fields: list[str]) -> str | None:
return fields[index].strip() if len(fields) > index else None
result: list[dict[str, Any]] = []
claims_values = (
claims.strip().removesuffix(completion_delimiter).split(record_delimiter)
)
for claim in claims_values:
claim = claim.strip().removeprefix("(").removesuffix(")")
# Ignore the completion delimiter
if claim == completion_delimiter:
continue
claim_fields = claim.split(tuple_delimiter)
result.append({
"subject_id": pull_field(0, claim_fields),
"object_id": pull_field(1, claim_fields),
"type": pull_field(2, claim_fields),
"status": pull_field(3, claim_fields),
"start_date": pull_field(4, claim_fields),
"end_date": pull_field(5, claim_fields),
"description": pull_field(6, claim_fields),
"source_text": pull_field(7, claim_fields),
})
return result
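
To make the delimiter scheme concrete, here is a self-contained sketch of how a raw extraction string splits into claim fields, mirroring _parse_claim_tuples above; the sample record is fabricated for illustration.

record_delimiter, tuple_delimiter, completion_delimiter = "##", "<|>", "<|COMPLETE|>"

# Fabricated model output: one claim record followed by the completion marker.
raw = (
    '("COMPANY A"<|>"AGENCY B"<|>"ANTI-COMPETITIVE PRACTICES"<|>"TRUE"'
    '<|>"2022-01-10"<|>"2022-01-10"<|>"Company A was fined for bid rigging."'
    '<|>"According to the article, Company A was fined...")##<|COMPLETE|>'
)

field_names = [
    "subject_id", "object_id", "type", "status",
    "start_date", "end_date", "description", "source_text",
]

claims = []
for record in raw.strip().removesuffix(completion_delimiter).split(record_delimiter):
    record = record.strip().removeprefix("(").removesuffix(")")
    if not record or record == completion_delimiter:
        continue
    fields = [f.strip() for f in record.split(tuple_delimiter)]
    claims.append(dict(zip(field_names, fields)))

print(claims[0]["subject_id"])  # "COMPANY A"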

@@ -0,0 +1,308 @@
# Copyright (c) 2024 Microsoft Corporation.
# Licensed under the MIT License
"""A module containing 'GraphExtractionResult' and 'GraphExtractor' models."""
import logging
import re
import traceback
from collections.abc import Mapping
from dataclasses import dataclass
from typing import Any
import networkx as nx
import tiktoken
from fnllm import ChatLLM
import graphrag.config.defaults as defs
from graphrag.index.typing import ErrorHandlerFn
from graphrag.index.utils.string import clean_str
from graphrag.prompts.index.entity_extraction import (
CONTINUE_PROMPT,
GRAPH_EXTRACTION_PROMPT,
LOOP_PROMPT,
)
DEFAULT_TUPLE_DELIMITER = "<|>"
DEFAULT_RECORD_DELIMITER = "##"
DEFAULT_COMPLETION_DELIMITER = "<|COMPLETE|>"
DEFAULT_ENTITY_TYPES = ["organization", "person", "geo", "event"]
log = logging.getLogger(__name__)
@dataclass
class GraphExtractionResult:
"""Unipartite graph extraction result class definition."""
output: nx.Graph
source_docs: dict[Any, Any]
class GraphExtractor:
"""Unipartite graph extractor class definition."""
_llm: ChatLLM
_join_descriptions: bool
_tuple_delimiter_key: str
_record_delimiter_key: str
_entity_types_key: str
_input_text_key: str
_completion_delimiter_key: str
_entity_name_key: str
_input_descriptions_key: str
_extraction_prompt: str
_summarization_prompt: str
_loop_args: dict[str, Any]
_max_gleanings: int
_on_error: ErrorHandlerFn
def __init__(
self,
llm_invoker: ChatLLM,
tuple_delimiter_key: str | None = None,
record_delimiter_key: str | None = None,
input_text_key: str | None = None,
entity_types_key: str | None = None,
completion_delimiter_key: str | None = None,
prompt: str | None = None,
join_descriptions=True,
encoding_model: str | None = None,
max_gleanings: int | None = None,
on_error: ErrorHandlerFn | None = None,
):
"""Init method definition."""
# TODO: streamline construction
self._llm = llm_invoker
self._join_descriptions = join_descriptions
self._input_text_key = input_text_key or "input_text"
self._tuple_delimiter_key = tuple_delimiter_key or "tuple_delimiter"
self._record_delimiter_key = record_delimiter_key or "record_delimiter"
self._completion_delimiter_key = (
completion_delimiter_key or "completion_delimiter"
)
self._entity_types_key = entity_types_key or "entity_types"
self._extraction_prompt = prompt or GRAPH_EXTRACTION_PROMPT
self._max_gleanings = (
max_gleanings
if max_gleanings is not None
else defs.ENTITY_EXTRACTION_MAX_GLEANINGS
)
self._on_error = on_error or (lambda _e, _s, _d: None)
# Construct the looping arguments
encoding = tiktoken.get_encoding(encoding_model or "cl100k_base")
yes = f"{encoding.encode('YES')[0]}"
no = f"{encoding.encode('NO')[0]}"
self._loop_args = {"logit_bias": {yes: 100, no: 100}, "max_tokens": 1}
async def __call__(
self, texts: list[str], prompt_variables: dict[str, Any] | None = None
) -> GraphExtractionResult:
"""Call method definition."""
if prompt_variables is None:
prompt_variables = {}
all_records: dict[int, str] = {}
source_doc_map: dict[int, str] = {}
# Wire defaults into the prompt variables
prompt_variables = {
**prompt_variables,
self._tuple_delimiter_key: prompt_variables.get(self._tuple_delimiter_key)
or DEFAULT_TUPLE_DELIMITER,
self._record_delimiter_key: prompt_variables.get(self._record_delimiter_key)
or DEFAULT_RECORD_DELIMITER,
self._completion_delimiter_key: prompt_variables.get(
self._completion_delimiter_key
)
or DEFAULT_COMPLETION_DELIMITER,
self._entity_types_key: ",".join(
prompt_variables[self._entity_types_key] or DEFAULT_ENTITY_TYPES
),
}
for doc_index, text in enumerate(texts):
try:
# Invoke the entity extraction
result = await self._process_document(text, prompt_variables)
source_doc_map[doc_index] = text
all_records[doc_index] = result
except Exception as e:
log.exception("error extracting graph")
self._on_error(
e,
traceback.format_exc(),
{
"doc_index": doc_index,
"text": text,
},
)
output = await self._process_results(
all_records,
prompt_variables.get(self._tuple_delimiter_key, DEFAULT_TUPLE_DELIMITER),
prompt_variables.get(self._record_delimiter_key, DEFAULT_RECORD_DELIMITER),
)
return GraphExtractionResult(
output=output,
source_docs=source_doc_map,
)
async def _process_document(
self, text: str, prompt_variables: dict[str, str]
) -> str:
response = await self._llm(
self._extraction_prompt.format(**{
**prompt_variables,
self._input_text_key: text,
}),
)
results = response.output.content or ""
# Repeat to ensure we maximize entity count
for i in range(self._max_gleanings):
response = await self._llm(
CONTINUE_PROMPT,
name=f"extract-continuation-{i}",
history=response.history,
)
results += response.output.content or ""
# if this is the final glean, don't bother updating the continuation flag
if i >= self._max_gleanings - 1:
break
response = await self._llm(
LOOP_PROMPT,
name=f"extract-loopcheck-{i}",
history=response.history,
model_parameters=self._loop_args,
)
if response.output.content != "YES":
break
return results
async def _process_results(
self,
results: dict[int, str],
tuple_delimiter: str,
record_delimiter: str,
) -> nx.Graph:
"""Parse the result string to create an undirected unipartite graph.
Args:
- results - dict of results from the extraction chain
- tuple_delimiter - delimiter between tuples in an output record, default is '<|>'
- record_delimiter - delimiter between records, default is '##'
Returns:
- output - unipartite graph in graphML format
"""
graph = nx.Graph()
for source_doc_id, extracted_data in results.items():
records = [r.strip() for r in extracted_data.split(record_delimiter)]
for record in records:
record = re.sub(r"^\(|\)$", "", record.strip())
record_attributes = record.split(tuple_delimiter)
if record_attributes[0] == '"entity"' and len(record_attributes) >= 4:
# add this record as a node in the G
entity_name = clean_str(record_attributes[1].upper())
entity_type = clean_str(record_attributes[2].upper())
entity_description = clean_str(record_attributes[3])
if entity_name in graph.nodes():
node = graph.nodes[entity_name]
if self._join_descriptions:
node["description"] = "\n".join(
list({
*_unpack_descriptions(node),
entity_description,
})
)
else:
if len(entity_description) > len(node["description"]):
node["description"] = entity_description
node["source_id"] = ", ".join(
list({
*_unpack_source_ids(node),
str(source_doc_id),
})
)
node["type"] = (
entity_type if entity_type != "" else node["type"]
)
else:
graph.add_node(
entity_name,
type=entity_type,
description=entity_description,
source_id=str(source_doc_id),
)
if (
record_attributes[0] == '"relationship"'
and len(record_attributes) >= 5
):
# add this record as edge
source = clean_str(record_attributes[1].upper())
target = clean_str(record_attributes[2].upper())
edge_description = clean_str(record_attributes[3])
edge_source_id = clean_str(str(source_doc_id))
try:
weight = float(record_attributes[-1])
except ValueError:
weight = 1.0
if source not in graph.nodes():
graph.add_node(
source,
type="",
description="",
source_id=edge_source_id,
)
if target not in graph.nodes():
graph.add_node(
target,
type="",
description="",
source_id=edge_source_id,
)
if graph.has_edge(source, target):
edge_data = graph.get_edge_data(source, target)
if edge_data is not None:
weight += edge_data["weight"]
if self._join_descriptions:
edge_description = "\n".join(
list({
*_unpack_descriptions(edge_data),
edge_description,
})
)
edge_source_id = ", ".join(
list({
*_unpack_source_ids(edge_data),
str(source_doc_id),
})
)
graph.add_edge(
source,
target,
weight=weight,
description=edge_description,
source_id=edge_source_id,
)
return graph
def _unpack_descriptions(data: Mapping) -> list[str]:
value = data.get("description", None)
return [] if value is None else value.split("\n")
def _unpack_source_ids(data: Mapping) -> list[str]:
value = data.get("source_id", None)
return [] if value is None else value.split(", ")
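
The same delimiter convention drives _process_results above. The following self-contained sketch, with fabricated records, shows how "entity" and "relationship" tuples become nodes and edges; it simplifies the cleaning that clean_str performs.

import networkx as nx

# Fabricated extraction output: two entities and one relationship, "##"-delimited.
extracted = (
    '("entity"<|>"ACME CORP"<|>"ORGANIZATION"<|>"A fictional widget maker")##'
    '("entity"<|>"JANE DOE"<|>"PERSON"<|>"Chief executive of Acme Corp")##'
    '("relationship"<|>"JANE DOE"<|>"ACME CORP"<|>"Jane Doe leads Acme Corp"<|>9)'
)

graph = nx.Graph()
for record in extracted.split("##"):
    fields = record.strip().removeprefix("(").removesuffix(")").split("<|>")
    if fields[0] == '"entity"' and len(fields) >= 4:
        graph.add_node(fields[1].strip('"'), type=fields[2].strip('"'),
                       description=fields[3].strip('"'), source_id="0")
    elif fields[0] == '"relationship"' and len(fields) >= 5:
        graph.add_edge(fields[1].strip('"'), fields[2].strip('"'),
                       description=fields[3].strip('"'), weight=float(fields[-1]),
                       source_id="0")

print(sorted(graph.nodes))                       # ['ACME CORP', 'JANE DOE']
print(graph["JANE DOE"]["ACME CORP"]["weight"])  # 9.0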

@@ -0,0 +1,139 @@
# Copyright (c) 2024 Microsoft Corporation.
# Licensed under the MIT License
"""A module containing run_graph_intelligence, run_extract_entities and _create_text_splitter methods to run graph intelligence."""
import networkx as nx
from datashaper import VerbCallbacks
from fnllm import ChatLLM
import graphrag.config.defaults as defs
from graphrag.cache.pipeline_cache import PipelineCache
from graphrag.index.llm.load_llm import load_llm, read_llm_params
from graphrag.index.operations.extract_entities.graph_extractor import GraphExtractor
from graphrag.index.operations.extract_entities.typing import (
Document,
EntityExtractionResult,
EntityTypes,
StrategyConfig,
)
from graphrag.index.text_splitting.text_splitting import (
NoopTextSplitter,
TextSplitter,
TokenTextSplitter,
)
async def run_graph_intelligence(
docs: list[Document],
entity_types: EntityTypes,
callbacks: VerbCallbacks,
cache: PipelineCache,
args: StrategyConfig,
) -> EntityExtractionResult:
"""Run the graph intelligence entity extraction strategy."""
llm_config = read_llm_params(args.get("llm", {}))
llm = load_llm("entity_extraction", llm_config, callbacks=callbacks, cache=cache)
return await run_extract_entities(llm, docs, entity_types, callbacks, args)
async def run_extract_entities(
llm: ChatLLM,
docs: list[Document],
entity_types: EntityTypes,
callbacks: VerbCallbacks | None,
args: StrategyConfig,
) -> EntityExtractionResult:
"""Run the entity extraction chain."""
encoding_name = args.get("encoding_name", "cl100k_base")
# Chunking Arguments
prechunked = args.get("prechunked", False)
chunk_size = args.get("chunk_size", defs.CHUNK_SIZE)
chunk_overlap = args.get("chunk_overlap", defs.CHUNK_OVERLAP)
# Extraction Arguments
tuple_delimiter = args.get("tuple_delimiter", None)
record_delimiter = args.get("record_delimiter", None)
completion_delimiter = args.get("completion_delimiter", None)
extraction_prompt = args.get("extraction_prompt", None)
encoding_model = args.get("encoding_name", None)
max_gleanings = args.get("max_gleanings", defs.ENTITY_EXTRACTION_MAX_GLEANINGS)
# note: We're not using UnipartiteGraphChain.from_params
# because we want to pass "timeout" to the llm_kwargs
text_splitter = _create_text_splitter(
prechunked, chunk_size, chunk_overlap, encoding_name
)
extractor = GraphExtractor(
llm_invoker=llm,
prompt=extraction_prompt,
encoding_model=encoding_model,
max_gleanings=max_gleanings,
on_error=lambda e, s, d: (
callbacks.error("Entity Extraction Error", e, s, d) if callbacks else None
),
)
text_list = [doc.text.strip() for doc in docs]
# If it's not pre-chunked, then re-chunk the input
if not prechunked:
text_list = text_splitter.split_text("\n".join(text_list))
results = await extractor(
list(text_list),
{
"entity_types": entity_types,
"tuple_delimiter": tuple_delimiter,
"record_delimiter": record_delimiter,
"completion_delimiter": completion_delimiter,
},
)
graph = results.output
# Map the "source_id" back to the "id" field
for _, node in graph.nodes(data=True): # type: ignore
if node is not None:
node["source_id"] = ",".join(
docs[int(id)].id for id in node["source_id"].split(",")
)
for _, _, edge in graph.edges(data=True): # type: ignore
if edge is not None:
edge["source_id"] = ",".join(
docs[int(id)].id for id in edge["source_id"].split(",")
)
entities = [
({"title": item[0], **(item[1] or {})})
for item in graph.nodes(data=True)
if item is not None
]
relationships = nx.to_pandas_edgelist(graph)
return EntityExtractionResult(entities, relationships, graph)
def _create_text_splitter(
prechunked: bool, chunk_size: int, chunk_overlap: int, encoding_name: str
) -> TextSplitter:
"""Create a text splitter for the extraction chain.
Args:
- prechunked - Whether the text is already chunked
- chunk_size - The size of each chunk
- chunk_overlap - The overlap between chunks
- encoding_name - The name of the encoding to use
Returns:
- output - A text splitter
"""
if prechunked:
return NoopTextSplitter()
return TokenTextSplitter(
chunk_size=chunk_size,
chunk_overlap=chunk_overlap,
encoding_name=encoding_name,
)

@@ -0,0 +1,66 @@
# Copyright (c) 2024 Microsoft Corporation.
# Licensed under the MIT License
"""A module containing run method definition."""
import networkx as nx
import nltk
from datashaper import VerbCallbacks
from nltk.corpus import words
from graphrag.cache.pipeline_cache import PipelineCache
from graphrag.index.operations.extract_entities.typing import (
Document,
EntityExtractionResult,
EntityTypes,
StrategyConfig,
)
# Need to do this because we're potentially multithreading, and nltk doesn't like that
words.ensure_loaded()
async def run( # noqa RUF029 async is required for interface
docs: list[Document],
entity_types: EntityTypes,
callbacks: VerbCallbacks, # noqa ARG001
cache: PipelineCache, # noqa ARG001
args: StrategyConfig, # noqa ARG001
) -> EntityExtractionResult:
"""Run method definition."""
entity_map = {}
graph = nx.Graph()
for doc in docs:
connected_entities = []
for chunk in nltk.ne_chunk(nltk.pos_tag(nltk.word_tokenize(doc.text))):
if hasattr(chunk, "label"):
entity_type = chunk.label().lower()
if entity_type in entity_types:
name = (" ".join(c[0] for c in chunk)).upper()
connected_entities.append(name)
if name not in entity_map:
entity_map[name] = entity_type
graph.add_node(
name, type=entity_type, description=name, source_id=doc.id
)
# connect the entities if they appear in the same document
if len(connected_entities) > 1:
for i in range(len(connected_entities)):
for j in range(i + 1, len(connected_entities)):
description = f"{connected_entities[i]} -> {connected_entities[j]}"
graph.add_edge(
connected_entities[i],
connected_entities[j],
description=description,
source_id=doc.id,
)
return EntityExtractionResult(
entities=[
{"type": entity_type, "title": name}
for name, entity_type in entity_map.items()
],
relationships=[],
graph=graph,
)

@@ -0,0 +1,59 @@
# Copyright (c) 2024 Microsoft Corporation.
# Licensed under the MIT License
"""A module containing 'Document' and 'EntityExtractionResult' models."""
from collections.abc import Awaitable, Callable
from dataclasses import dataclass
from enum import Enum
from typing import Any
import networkx as nx
from datashaper import VerbCallbacks
from graphrag.cache.pipeline_cache import PipelineCache
ExtractedEntity = dict[str, Any]
ExtractedRelationship = dict[str, Any]
StrategyConfig = dict[str, Any]
EntityTypes = list[str]
@dataclass
class Document:
"""Document class definition."""
text: str
id: str
@dataclass
class EntityExtractionResult:
"""Entity extraction result class definition."""
entities: list[ExtractedEntity]
relationships: list[ExtractedRelationship]
graph: nx.Graph | None
EntityExtractStrategy = Callable[
[
list[Document],
EntityTypes,
VerbCallbacks,
PipelineCache,
StrategyConfig,
],
Awaitable[EntityExtractionResult],
]
class ExtractEntityStrategyType(str, Enum):
"""ExtractEntityStrategyType class definition."""
graph_intelligence = "graph_intelligence"
nltk = "nltk"
def __repr__(self):
"""Get a string representation."""
return f'"{self.value}"'

@@ -0,0 +1,27 @@
# Copyright (c) 2024 Microsoft Corporation.
# Licensed under the MIT License
# Use this for now instead of a wrapper
"""A module containing 'NodePosition' model."""
from dataclasses import dataclass
@dataclass
class NodePosition:
"""Node position class definition."""
label: str
cluster: str
size: float
x: float
y: float
z: float | None = None
def to_pandas(self) -> tuple[str, float, float, str, float]:
"""To pandas method definition."""
return self.label, self.x, self.y, self.cluster, self.size
GraphLayout = list[NodePosition]

@@ -0,0 +1,136 @@
# Copyright (c) 2024 Microsoft Corporation.
# Licensed under the MIT License
"""A module containing run and _create_node_position methods definitions."""
import logging
import traceback
from typing import Any
import networkx as nx
import numpy as np
from graphrag.index.operations.embed_graph.typing import NodeEmbeddings
from graphrag.index.operations.layout_graph.typing import (
GraphLayout,
NodePosition,
)
from graphrag.index.typing import ErrorHandlerFn
# TODO: This could be handled more elegantly, like what columns to use
# for "size" or "cluster"
# We could also have a boolean to indicate to use node sizes or clusters
log = logging.getLogger(__name__)
def run(
graph: nx.Graph,
embeddings: NodeEmbeddings,
args: dict[str, Any],
on_error: ErrorHandlerFn,
) -> GraphLayout:
"""Run method definition."""
node_clusters = []
node_sizes = []
embeddings = _filter_raw_embeddings(embeddings)
nodes = list(embeddings.keys())
embedding_vectors = [embeddings[node_id] for node_id in nodes]
for node_id in nodes:
node = graph.nodes[node_id]
cluster = node.get("cluster", node.get("community", -1))
node_clusters.append(cluster)
size = node.get("degree", node.get("size", 0))
node_sizes.append(size)
additional_args = {}
if len(node_clusters) > 0:
additional_args["node_categories"] = node_clusters
if len(node_sizes) > 0:
additional_args["node_sizes"] = node_sizes
try:
return compute_umap_positions(
embedding_vectors=np.array(embedding_vectors),
node_labels=nodes,
**additional_args,
min_dist=args.get("min_dist", 0.75),
n_neighbors=args.get("n_neighbors", 5),
)
except Exception as e:
log.exception("Error running UMAP")
on_error(e, traceback.format_exc(), None)
# Umap may fail due to input sparseness or memory pressure.
# For now, in these cases, we'll just return a layout with all nodes at (0, 0)
result = []
for i in range(len(nodes)):
cluster = node_clusters[i] if len(node_clusters) > 0 else 1
result.append(
NodePosition(x=0, y=0, label=nodes[i], size=0, cluster=str(cluster))
)
return result
def _filter_raw_embeddings(embeddings: NodeEmbeddings) -> NodeEmbeddings:
return {
node_id: embedding
for node_id, embedding in embeddings.items()
if embedding is not None
}
def compute_umap_positions(
embedding_vectors: np.ndarray,
node_labels: list[str],
node_categories: list[int] | None = None,
node_sizes: list[int] | None = None,
min_dist: float = 0.75,
n_neighbors: int = 25,
spread: int = 1,
metric: str = "euclidean",
n_components: int = 2,
random_state: int = 86,
) -> list[NodePosition]:
"""Project embedding vectors down to 2D/3D using UMAP."""
# NOTE: This import is done here to reduce the initial import time of the graphrag package
import umap
embedding_positions = umap.UMAP(
min_dist=min_dist,
n_neighbors=n_neighbors,
spread=spread,
n_components=n_components,
metric=metric,
random_state=random_state,
).fit_transform(embedding_vectors)
embedding_position_data: list[NodePosition] = []
for index, node_name in enumerate(node_labels):
node_points = embedding_positions[index] # type: ignore
node_category = 1 if node_categories is None else node_categories[index]
node_size = 1 if node_sizes is None else node_sizes[index]
if len(node_points) == 2:
embedding_position_data.append(
NodePosition(
label=str(node_name),
x=float(node_points[0]),
y=float(node_points[1]),
cluster=str(int(node_category)),
size=int(node_size),
)
)
else:
embedding_position_data.append(
NodePosition(
label=str(node_name),
x=float(node_points[0]),
y=float(node_points[1]),
z=float(node_points[2]),
cluster=str(int(node_category)),
size=int(node_size),
)
)
return embedding_position_data

@@ -0,0 +1,98 @@
# Copyright (c) 2024 Microsoft Corporation.
# Licensed under the MIT License
"""A module containing run and _create_node_position methods definitions."""
import logging
import traceback
from typing import Any
import networkx as nx
from graphrag.index.operations.layout_graph.typing import (
GraphLayout,
NodePosition,
)
from graphrag.index.typing import ErrorHandlerFn
# TODO: This could be handled more elegantly, like what columns to use
# for "size" or "cluster"
# We could also have a boolean to indicate to use node sizes or clusters
log = logging.getLogger(__name__)
def run(
graph: nx.Graph,
_args: dict[str, Any],
on_error: ErrorHandlerFn,
) -> GraphLayout:
"""Run method definition."""
node_clusters = []
node_sizes = []
nodes = list(graph.nodes)
for node_id in nodes:
node = graph.nodes[node_id]
cluster = node.get("cluster", node.get("community", -1))
node_clusters.append(cluster)
size = node.get("degree", node.get("size", 0))
node_sizes.append(size)
additional_args = {}
if len(node_clusters) > 0:
additional_args["node_categories"] = node_clusters
if len(node_sizes) > 0:
additional_args["node_sizes"] = node_sizes
try:
return get_zero_positions(node_labels=nodes, **additional_args)
except Exception as e:
log.exception("Error running zero-position")
on_error(e, traceback.format_exc(), None)
# If layout generation fails for any reason,
# fall back to a layout with all nodes at (0, 0)
result = []
for i in range(len(nodes)):
cluster = node_clusters[i] if len(node_clusters) > 0 else 1
result.append(
NodePosition(x=0, y=0, label=nodes[i], size=0, cluster=str(cluster))
)
return result
def get_zero_positions(
node_labels: list[str],
node_categories: list[int] | None = None,
node_sizes: list[int] | None = None,
three_d: bool | None = False,
) -> list[NodePosition]:
"""Project embedding vectors down to 2D/3D using UMAP."""
embedding_position_data: list[NodePosition] = []
for index, node_name in enumerate(node_labels):
node_category = 1 if node_categories is None else node_categories[index]
node_size = 1 if node_sizes is None else node_sizes[index]
if not three_d:
embedding_position_data.append(
NodePosition(
label=str(node_name),
x=0,
y=0,
cluster=str(int(node_category)),
size=int(node_size),
)
)
else:
embedding_position_data.append(
NodePosition(
label=str(node_name),
x=0,
y=0,
z=0,
cluster=str(int(node_category)),
size=int(node_size),
)
)
return embedding_position_data

@@ -0,0 +1,26 @@
# Copyright (c) 2024 Microsoft Corporation.
# Licensed under the MIT License
"""The Indexing Engine community reports package root."""
import graphrag.index.operations.summarize_communities.community_reports_extractor.schemas as schemas
from graphrag.index.operations.summarize_communities.community_reports_extractor.build_mixed_context import (
build_mixed_context,
)
from graphrag.index.operations.summarize_communities.community_reports_extractor.community_reports_extractor import (
CommunityReportsExtractor,
)
from graphrag.index.operations.summarize_communities.community_reports_extractor.prep_community_report_context import (
prep_community_report_context,
)
from graphrag.index.operations.summarize_communities.community_reports_extractor.sort_context import (
sort_context,
)
__all__ = [
"CommunityReportsExtractor",
"build_mixed_context",
"prep_community_report_context",
"schemas",
"sort_context",
]

@@ -0,0 +1,70 @@
# Copyright (c) 2024 Microsoft Corporation.
# Licensed under the MIT License
"""A module containing the build_mixed_context method definition."""
import pandas as pd
import graphrag.index.operations.summarize_communities.community_reports_extractor.schemas as schemas
from graphrag.index.operations.summarize_communities.community_reports_extractor.sort_context import (
sort_context,
)
from graphrag.query.llm.text_utils import num_tokens
def build_mixed_context(context: list[dict], max_tokens: int) -> str:
"""
Build parent context by concatenating all sub-communities' contexts.
If the context exceeds the limit, we use sub-community reports instead.
"""
sorted_context = sorted(
context, key=lambda x: x[schemas.CONTEXT_SIZE], reverse=True
)
# replace local context with sub-community reports, starting from the biggest sub-community
substitute_reports = []
final_local_contexts = []
exceeded_limit = True
context_string = ""
for idx, sub_community_context in enumerate(sorted_context):
if exceeded_limit:
if sub_community_context[schemas.FULL_CONTENT]:
substitute_reports.append({
schemas.COMMUNITY_ID: sub_community_context[schemas.SUB_COMMUNITY],
schemas.FULL_CONTENT: sub_community_context[schemas.FULL_CONTENT],
})
else:
# this sub-community has no report, so we will use its local context
final_local_contexts.extend(sub_community_context[schemas.ALL_CONTEXT])
continue
# add local context for the remaining sub-communities
remaining_local_context = []
for rid in range(idx + 1, len(sorted_context)):
remaining_local_context.extend(sorted_context[rid][schemas.ALL_CONTEXT])
new_context_string = sort_context(
local_context=remaining_local_context + final_local_contexts,
sub_community_reports=substitute_reports,
)
if num_tokens(new_context_string) <= max_tokens:
exceeded_limit = False
context_string = new_context_string
break
if exceeded_limit:
# if all sub-community reports exceed the limit, we add reports until context is full
substitute_reports = []
for sub_community_context in sorted_context:
substitute_reports.append({
schemas.COMMUNITY_ID: sub_community_context[schemas.SUB_COMMUNITY],
schemas.FULL_CONTENT: sub_community_context[schemas.FULL_CONTENT],
})
new_context_string = pd.DataFrame(substitute_reports).to_csv(
index=False, sep=","
)
if num_tokens(new_context_string) > max_tokens:
break
context_string = new_context_string
return context_string

@@ -0,0 +1,106 @@
# Copyright (c) 2024 Microsoft Corporation.
# Licensed under the MIT License
"""A module containing 'CommunityReportsResult' and 'CommunityReportsExtractor' models."""
import logging
import traceback
from dataclasses import dataclass
from typing import Any
from fnllm import ChatLLM
from pydantic import BaseModel, Field
from graphrag.index.typing import ErrorHandlerFn
from graphrag.prompts.index.community_report import COMMUNITY_REPORT_PROMPT
log = logging.getLogger(__name__)
class FindingModel(BaseModel):
"""A model for the expected LLM response shape."""
summary: str = Field(description="The summary of the finding.")
explanation: str = Field(description="An explanation of the finding.")
class CommunityReportResponse(BaseModel):
"""A model for the expected LLM response shape."""
title: str = Field(description="The title of the report.")
summary: str = Field(description="A summary of the report.")
findings: list[FindingModel] = Field(
description="A list of findings in the report."
)
rating: float = Field(description="The rating of the report.")
rating_explanation: str = Field(description="An explanation of the rating.")
extra_attributes: dict[str, Any] = Field(
default_factory=dict, description="Extra attributes."
)
@dataclass
class CommunityReportsResult:
"""Community reports result class definition."""
output: str
structured_output: CommunityReportResponse | None
class CommunityReportsExtractor:
"""Community reports extractor class definition."""
_llm: ChatLLM
_input_text_key: str
_extraction_prompt: str
_output_formatter_prompt: str
_on_error: ErrorHandlerFn
_max_report_length: int
def __init__(
self,
llm_invoker: ChatLLM,
input_text_key: str | None = None,
extraction_prompt: str | None = None,
on_error: ErrorHandlerFn | None = None,
max_report_length: int | None = None,
):
"""Init method definition."""
self._llm = llm_invoker
self._input_text_key = input_text_key or "input_text"
self._extraction_prompt = extraction_prompt or COMMUNITY_REPORT_PROMPT
self._on_error = on_error or (lambda _e, _s, _d: None)
self._max_report_length = max_report_length or 1500
async def __call__(self, inputs: dict[str, Any]):
"""Call method definition."""
output = None
try:
input_text = inputs[self._input_text_key]
prompt = self._extraction_prompt.replace(
"{" + self._input_text_key + "}", input_text
)
response = await self._llm(
prompt,
json=True,
name="create_community_report",
json_model=CommunityReportResponse,
model_parameters={"max_tokens": self._max_report_length},
)
output = response.parsed_json
except Exception as e:
log.exception("error generating community report")
self._on_error(e, traceback.format_exc(), None)
text_output = self._get_text_output(output) if output else ""
return CommunityReportsResult(
structured_output=output,
output=text_output,
)
def _get_text_output(self, report: CommunityReportResponse) -> str:
report_sections = "\n\n".join(
f"## {f.summary}\n\n{f.explanation}" for f in report.findings
)
return f"# {report.title}\n\n{report.summary}\n\n{report_sections}"

@@ -0,0 +1,185 @@
# Copyright (c) 2024 Microsoft Corporation.
# Licensed under the MIT License
"""A module containing create_community_reports and load_strategy methods definition."""
import logging
from typing import cast
import pandas as pd
import graphrag.index.operations.summarize_communities.community_reports_extractor.schemas as schemas
from graphrag.index.operations.summarize_communities.community_reports_extractor.build_mixed_context import (
build_mixed_context,
)
from graphrag.index.operations.summarize_communities.community_reports_extractor.sort_context import (
sort_context,
)
from graphrag.index.utils.dataframes import (
antijoin,
drop_columns,
join,
select,
transform_series,
union,
where_column_equals,
)
from graphrag.query.llm.text_utils import num_tokens
log = logging.getLogger(__name__)
def prep_community_report_context(
report_df: pd.DataFrame | None,
community_hierarchy_df: pd.DataFrame,
local_context_df: pd.DataFrame,
level: int,
max_tokens: int,
) -> pd.DataFrame:
"""
Prep context for each community in a given level.
For each community:
- Check if local context fits within the limit, if yes use local context
- If local context exceeds the limit, iteratively replace local context with sub-community reports, starting from the biggest sub-community
"""
if report_df is None:
report_df = pd.DataFrame()
# Filter by community level
level_context_df = local_context_df.loc[
local_context_df.loc[:, schemas.COMMUNITY_LEVEL] == level
]
# Filter valid and invalid contexts using boolean logic
valid_context_df = level_context_df.loc[
~level_context_df.loc[:, schemas.CONTEXT_EXCEED_FLAG]
]
invalid_context_df = level_context_df.loc[
level_context_df.loc[:, schemas.CONTEXT_EXCEED_FLAG]
]
# there is no report to substitute with, so we just trim the local context of the invalid context records
# this case should only happen at the bottom level of the community hierarchy where there are no sub-communities
if invalid_context_df.empty:
return valid_context_df
if report_df.empty:
invalid_context_df.loc[:, schemas.CONTEXT_STRING] = _sort_and_trim_context(
invalid_context_df, max_tokens
)
invalid_context_df[schemas.CONTEXT_SIZE] = invalid_context_df.loc[
:, schemas.CONTEXT_STRING
].map(num_tokens)
invalid_context_df[schemas.CONTEXT_EXCEED_FLAG] = 0
return union(valid_context_df, invalid_context_df)
level_context_df = _antijoin_reports(level_context_df, report_df)
# for each invalid context, we will try to substitute with sub-community reports
# first get local context and report (if available) for each sub-community
sub_context_df = _get_subcontext_df(level + 1, report_df, local_context_df)
community_df = _get_community_df(
level, invalid_context_df, sub_context_df, community_hierarchy_df, max_tokens
)
# handle any remaining invalid records that can't be substituted with sub-community reports
# this should be rare, but if it happens, we will just trim the local context to fit the limit
remaining_df = _antijoin_reports(invalid_context_df, community_df)
remaining_df.loc[:, schemas.CONTEXT_STRING] = _sort_and_trim_context(
remaining_df, max_tokens
)
result = union(valid_context_df, community_df, remaining_df)
result[schemas.CONTEXT_SIZE] = result.loc[:, schemas.CONTEXT_STRING].map(num_tokens)
result[schemas.CONTEXT_EXCEED_FLAG] = 0
return result
def _drop_community_level(df: pd.DataFrame) -> pd.DataFrame:
"""Drop the community level column from the dataframe."""
return drop_columns(df, schemas.COMMUNITY_LEVEL)
def _at_level(level: int, df: pd.DataFrame) -> pd.DataFrame:
"""Return records at the given level."""
return where_column_equals(df, schemas.COMMUNITY_LEVEL, level)
def _antijoin_reports(df: pd.DataFrame, reports: pd.DataFrame) -> pd.DataFrame:
"""Return records in df that are not in reports."""
return antijoin(df, reports, schemas.NODE_COMMUNITY)
def _sort_and_trim_context(df: pd.DataFrame, max_tokens: int) -> pd.Series:
"""Sort and trim context to fit the limit."""
series = cast("pd.Series", df[schemas.ALL_CONTEXT])
return transform_series(series, lambda x: sort_context(x, max_tokens=max_tokens))
def _build_mixed_context(df: pd.DataFrame, max_tokens: int) -> pd.Series:
"""Sort and trim context to fit the limit."""
series = cast("pd.Series", df[schemas.ALL_CONTEXT])
return transform_series(
series, lambda x: build_mixed_context(x, max_tokens=max_tokens)
)
def _get_subcontext_df(
level: int, report_df: pd.DataFrame, local_context_df: pd.DataFrame
) -> pd.DataFrame:
"""Get sub-community context for each community."""
sub_report_df = _drop_community_level(_at_level(level, report_df))
sub_context_df = _at_level(level, local_context_df)
sub_context_df = join(sub_context_df, sub_report_df, schemas.NODE_COMMUNITY)
sub_context_df.rename(
columns={schemas.NODE_COMMUNITY: schemas.SUB_COMMUNITY}, inplace=True
)
return sub_context_df
def _get_community_df(
level: int,
invalid_context_df: pd.DataFrame,
sub_context_df: pd.DataFrame,
community_hierarchy_df: pd.DataFrame,
max_tokens: int,
) -> pd.DataFrame:
"""Get community context for each community."""
# collect all sub communities' contexts for each community
community_df = _drop_community_level(_at_level(level, community_hierarchy_df))
invalid_community_ids = select(invalid_context_df, schemas.NODE_COMMUNITY)
subcontext_selection = select(
sub_context_df,
schemas.SUB_COMMUNITY,
schemas.FULL_CONTENT,
schemas.ALL_CONTEXT,
schemas.CONTEXT_SIZE,
)
invalid_communities = join(
community_df, invalid_community_ids, schemas.NODE_COMMUNITY, "inner"
)
community_df = join(
invalid_communities, subcontext_selection, schemas.SUB_COMMUNITY
)
community_df[schemas.ALL_CONTEXT] = community_df.apply(
lambda x: {
schemas.SUB_COMMUNITY: x[schemas.SUB_COMMUNITY],
schemas.ALL_CONTEXT: x[schemas.ALL_CONTEXT],
schemas.FULL_CONTENT: x[schemas.FULL_CONTENT],
schemas.CONTEXT_SIZE: x[schemas.CONTEXT_SIZE],
},
axis=1,
)
community_df = (
community_df.groupby(schemas.NODE_COMMUNITY)
.agg({schemas.ALL_CONTEXT: list})
.reset_index()
)
community_df[schemas.CONTEXT_STRING] = _build_mixed_context(
community_df, max_tokens
)
community_df[schemas.COMMUNITY_LEVEL] = level
return community_df

@@ -0,0 +1,52 @@
# Copyright (c) 2024 Microsoft Corporation.
# Licensed under the MIT License
"""Common field name definitions for community reports."""
# POST-PREP NODE TABLE SCHEMA
NODE_ID = "human_readable_id"
NODE_NAME = "title"
NODE_DESCRIPTION = "description"
NODE_DEGREE = "degree"
NODE_DETAILS = "node_details"
NODE_COMMUNITY = "community"
NODE_LEVEL = "level"
# POST-PREP EDGE TABLE SCHEMA
EDGE_ID = "human_readable_id"
EDGE_SOURCE = "source"
EDGE_TARGET = "target"
EDGE_DESCRIPTION = "description"
EDGE_DEGREE = "combined_degree"
EDGE_DETAILS = "edge_details"
EDGE_WEIGHT = "weight"
# POST-PREP CLAIM TABLE SCHEMA
CLAIM_ID = "human_readable_id"
CLAIM_SUBJECT = "subject_id"
CLAIM_TYPE = "type"
CLAIM_STATUS = "status"
CLAIM_DESCRIPTION = "description"
CLAIM_DETAILS = "claim_details"
# COMMUNITY HIERARCHY TABLE SCHEMA
SUB_COMMUNITY = "sub_community"
SUB_COMMUNITY_SIZE = "sub_community_size"
COMMUNITY_LEVEL = "level"
# COMMUNITY CONTEXT TABLE SCHEMA
ALL_CONTEXT = "all_context"
CONTEXT_STRING = "context_string"
CONTEXT_SIZE = "context_size"
CONTEXT_EXCEED_FLAG = "context_exceed_limit"
# COMMUNITY REPORT TABLE SCHEMA
REPORT_ID = "id"
COMMUNITY_ID = "community"
COMMUNITY_LEVEL = "level"
TITLE = "title"
SUMMARY = "summary"
FINDINGS = "findings"
RATING = "rank"
EXPLANATION = "rating_explanation"
FULL_CONTENT = "full_content"
FULL_CONTENT_JSON = "full_content_json"

@@ -0,0 +1,154 @@
# Copyright (c) 2024 Microsoft Corporation.
# Licensed under the MIT License
"""Sort context by degree in descending order."""
import pandas as pd
import graphrag.index.operations.summarize_communities.community_reports_extractor.schemas as schemas
from graphrag.query.llm.text_utils import num_tokens
def sort_context(
local_context: list[dict],
sub_community_reports: list[dict] | None = None,
max_tokens: int | None = None,
node_name_column: str = schemas.NODE_NAME,
node_details_column: str = schemas.NODE_DETAILS,
edge_id_column: str = schemas.EDGE_ID,
edge_details_column: str = schemas.EDGE_DETAILS,
edge_degree_column: str = schemas.EDGE_DEGREE,
edge_source_column: str = schemas.EDGE_SOURCE,
edge_target_column: str = schemas.EDGE_TARGET,
claim_details_column: str = schemas.CLAIM_DETAILS,
) -> str:
"""Sort context by degree in descending order, optimizing for performance."""
def _get_context_string(
entities: list[dict],
edges: list[dict],
claims: list[dict],
sub_community_reports: list[dict] | None = None,
) -> str:
"""Concatenate structured data into a context string."""
contexts = []
if sub_community_reports:
report_df = pd.DataFrame(sub_community_reports)
if not report_df.empty:
contexts.append(
f"----Reports-----\n{report_df.to_csv(index=False, sep=',')}"
)
for label, data in [
("Entities", entities),
("Claims", claims),
("Relationships", edges),
]:
if data:
data_df = pd.DataFrame(data)
if not data_df.empty:
contexts.append(
f"-----{label}-----\n{data_df.to_csv(index=False, sep=',')}"
)
return "\n\n".join(contexts)
# Preprocess local context
edges = [
{**e, schemas.EDGE_ID: int(e[schemas.EDGE_ID])}
for record in local_context
for e in record.get(edge_details_column, [])
if isinstance(e, dict)
]
node_details = {
record[node_name_column]: {
**record[node_details_column],
schemas.NODE_ID: int(record[node_details_column][schemas.NODE_ID]),
}
for record in local_context
}
claim_details = {
record[node_name_column]: [
{**c, schemas.CLAIM_ID: int(c[schemas.CLAIM_ID])}
for c in record.get(claim_details_column, [])
if isinstance(c, dict) and c.get(schemas.CLAIM_ID) is not None
]
for record in local_context
if isinstance(record.get(claim_details_column), list)
}
# Sort edges by degree (desc) and ID (asc)
edges.sort(key=lambda x: (-x.get(edge_degree_column, 0), x.get(edge_id_column, "")))
# Deduplicate and build context incrementally
edge_ids, nodes_ids, claims_ids = set(), set(), set()
sorted_edges, sorted_nodes, sorted_claims = [], [], []
context_string = ""
for edge in edges:
source, target = edge[edge_source_column], edge[edge_target_column]
# Add source and target node details
for node in [node_details.get(source), node_details.get(target)]:
if node and node[schemas.NODE_ID] not in nodes_ids:
nodes_ids.add(node[schemas.NODE_ID])
sorted_nodes.append(node)
# Add claims related to source and target
for claims in [claim_details.get(source), claim_details.get(target)]:
if claims:
for claim in claims:
if claim[schemas.CLAIM_ID] not in claims_ids:
claims_ids.add(claim[schemas.CLAIM_ID])
sorted_claims.append(claim)
# Add the edge
if edge[schemas.EDGE_ID] not in edge_ids:
edge_ids.add(edge[schemas.EDGE_ID])
sorted_edges.append(edge)
# Generate new context string
new_context_string = _get_context_string(
sorted_nodes, sorted_edges, sorted_claims, sub_community_reports
)
if max_tokens and num_tokens(new_context_string) > max_tokens:
break
context_string = new_context_string
# Return the final context string
return context_string or _get_context_string(
sorted_nodes, sorted_edges, sorted_claims, sub_community_reports
)
def parallel_sort_context_batch(community_df, max_tokens, parallel=False):
"""Calculate context using parallelization if enabled."""
if parallel:
# Use ThreadPoolExecutor for parallel execution
from concurrent.futures import ThreadPoolExecutor
with ThreadPoolExecutor(max_workers=None) as executor:
context_strings = list(
executor.map(
lambda x: sort_context(x, max_tokens=max_tokens),
community_df[schemas.ALL_CONTEXT],
)
)
community_df[schemas.CONTEXT_STRING] = context_strings
else:
# Assign context strings directly to the DataFrame
community_df[schemas.CONTEXT_STRING] = community_df[schemas.ALL_CONTEXT].apply(
lambda context_list: sort_context(context_list, max_tokens=max_tokens)
)
# Calculate other columns
community_df[schemas.CONTEXT_SIZE] = community_df[schemas.CONTEXT_STRING].apply(
num_tokens
)
community_df[schemas.CONTEXT_EXCEED_FLAG] = (
community_df[schemas.CONTEXT_SIZE] > max_tokens
)
return community_df

@@ -0,0 +1,15 @@
# Copyright (c) 2024 Microsoft Corporation.
# Licensed under the MIT License
"""A module containing community report generation utilities."""
import pandas as pd
import graphrag.index.operations.summarize_communities.community_reports_extractor.schemas as schemas
def get_levels(df: pd.DataFrame, level_column: str = schemas.NODE_LEVEL) -> list[int]:
"""Get the levels of the communities."""
levels = df[level_column].dropna().unique()
levels = [int(lvl) for lvl in levels if lvl != -1]
return sorted(levels, reverse=True)
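
A small sketch with fabricated data of what get_levels returns: levels sorted deepest-first with the -1 sentinel dropped, so sub-communities are summarized before their parents.

import pandas as pd

nodes = pd.DataFrame({
    "title": ["A", "B", "C", "D", "E"],
    "level": [0, 0, 1, 2, -1],
})

levels = nodes["level"].dropna().unique()
levels = [int(lvl) for lvl in levels if lvl != -1]
print(sorted(levels, reverse=True))  # [2, 1, 0]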

@@ -0,0 +1,140 @@
# Copyright (c) 2024 Microsoft Corporation.
# Licensed under the MIT License
"""A module containing 'GraphExtractionResult' and 'GraphExtractor' models."""
import json
from dataclasses import dataclass
from fnllm import ChatLLM
from graphrag.index.typing import ErrorHandlerFn
from graphrag.index.utils.tokens import num_tokens_from_string
from graphrag.prompts.index.summarize_descriptions import SUMMARIZE_PROMPT
# Max token size for input prompts
DEFAULT_MAX_INPUT_TOKENS = 4_000
# Max token count for LLM answers
DEFAULT_MAX_SUMMARY_LENGTH = 500
@dataclass
class SummarizationResult:
"""Unipartite graph extraction result class definition."""
id: str | tuple[str, str]
description: str
class SummarizeExtractor:
"""Unipartite graph extractor class definition."""
_llm: ChatLLM
_entity_name_key: str
_input_descriptions_key: str
_summarization_prompt: str
_on_error: ErrorHandlerFn
_max_summary_length: int
_max_input_tokens: int
def __init__(
self,
llm_invoker: ChatLLM,
entity_name_key: str | None = None,
input_descriptions_key: str | None = None,
summarization_prompt: str | None = None,
on_error: ErrorHandlerFn | None = None,
max_summary_length: int | None = None,
max_input_tokens: int | None = None,
):
"""Init method definition."""
# TODO: streamline construction
self._llm = llm_invoker
self._entity_name_key = entity_name_key or "entity_name"
self._input_descriptions_key = input_descriptions_key or "description_list"
self._summarization_prompt = summarization_prompt or SUMMARIZE_PROMPT
self._on_error = on_error or (lambda _e, _s, _d: None)
self._max_summary_length = max_summary_length or DEFAULT_MAX_SUMMARY_LENGTH
self._max_input_tokens = max_input_tokens or DEFAULT_MAX_INPUT_TOKENS
async def __call__(
self,
id: str | tuple[str, str],
descriptions: list[str],
) -> SummarizationResult:
"""Call method definition."""
result = ""
if len(descriptions) == 0:
result = ""
elif len(descriptions) == 1:
result = descriptions[0]
else:
result = await self._summarize_descriptions(id, descriptions)
return SummarizationResult(
id=id,
description=result or "",
)
async def _summarize_descriptions(
self, id: str | tuple[str, str], descriptions: list[str]
) -> str:
"""Summarize descriptions into a single description."""
sorted_id = sorted(id) if isinstance(id, list) else id
# Safety check, should always be a list
if not isinstance(descriptions, list):
descriptions = [descriptions]
# Sort description lists
if len(descriptions) > 1:
descriptions = sorted(descriptions)
# Iterate over descriptions, adding all until the max input tokens is reached
usable_tokens = self._max_input_tokens - num_tokens_from_string(
self._summarization_prompt
)
descriptions_collected = []
result = ""
for i, description in enumerate(descriptions):
usable_tokens -= num_tokens_from_string(description)
descriptions_collected.append(description)
# If buffer is full, or all descriptions have been added, summarize
if (usable_tokens < 0 and len(descriptions_collected) > 1) or (
i == len(descriptions) - 1
):
# Calculate result (final or partial)
result = await self._summarize_descriptions_with_llm(
sorted_id, descriptions_collected
)
# If we go for another loop, reset values to new
if i != len(descriptions) - 1:
descriptions_collected = [result]
usable_tokens = (
self._max_input_tokens
- num_tokens_from_string(self._summarization_prompt)
- num_tokens_from_string(result)
)
return result
async def _summarize_descriptions_with_llm(
self, id: str | tuple[str, str] | list[str], descriptions: list[str]
):
"""Summarize descriptions using the LLM."""
response = await self._llm(
self._summarization_prompt.format(**{
self._entity_name_key: json.dumps(id, ensure_ascii=False),
self._input_descriptions_key: json.dumps(
sorted(descriptions), ensure_ascii=False
),
}),
name="summarize",
model_parameters={"max_tokens": self._max_summary_length},
)
# Calculate result
return str(response.output.content)
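
The token-budgeting loop in _summarize_descriptions is easier to follow without the LLM in the way. Below is a self-contained sketch of the same control flow, with a word count standing in for num_tokens_from_string and string joining standing in for the LLM call; the names and the budget are illustrative only.

def fake_summarize(items: list[str]) -> str:
    return " / ".join(items)  # stand-in for the summarization LLM call

def fake_tokens(text: str) -> int:
    return len(text.split())  # stand-in for num_tokens_from_string

def summarize_buffered(descriptions: list[str], max_input_tokens: int = 10) -> str:
    usable = max_input_tokens
    collected: list[str] = []
    result = ""
    for i, desc in enumerate(sorted(descriptions)):
        usable -= fake_tokens(desc)
        collected.append(desc)
        # Summarize when the budget is spent (and more than one item is buffered),
        # or when the last description has been added.
        if (usable < 0 and len(collected) > 1) or i == len(descriptions) - 1:
            result = fake_summarize(collected)
            if i != len(descriptions) - 1:
                # Carry the partial summary into the next round and reset the budget.
                collected = [result]
                usable = max_input_tokens - fake_tokens(result)
    return result

print(summarize_buffered([
    "alpha is a parser",
    "alpha emits entity graphs",
    "alpha is written in python",
    "alpha supports incremental runs",
]))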

@@ -0,0 +1,64 @@
# Copyright (c) 2024 Microsoft Corporation.
# Licensed under the MIT License
"""A module containing run_graph_intelligence, run_resolve_entities and _create_text_list_splitter methods to run graph intelligence."""
from datashaper import VerbCallbacks
from fnllm import ChatLLM
from graphrag.cache.pipeline_cache import PipelineCache
from graphrag.index.llm.load_llm import load_llm, read_llm_params
from graphrag.index.operations.summarize_descriptions.description_summary_extractor import (
SummarizeExtractor,
)
from graphrag.index.operations.summarize_descriptions.typing import (
StrategyConfig,
SummarizedDescriptionResult,
)
async def run_graph_intelligence(
id: str | tuple[str, str],
descriptions: list[str],
callbacks: VerbCallbacks,
cache: PipelineCache,
args: StrategyConfig,
) -> SummarizedDescriptionResult:
"""Run the graph intelligence entity extraction strategy."""
llm_config = read_llm_params(args.get("llm", {}))
llm = load_llm(
"summarize_descriptions", llm_config, callbacks=callbacks, cache=cache
)
return await run_summarize_descriptions(llm, id, descriptions, callbacks, args)
async def run_summarize_descriptions(
llm: ChatLLM,
id: str | tuple[str, str],
descriptions: list[str],
callbacks: VerbCallbacks,
args: StrategyConfig,
) -> SummarizedDescriptionResult:
"""Run the entity extraction chain."""
# Extraction Arguments
summarize_prompt = args.get("summarize_prompt", None)
entity_name_key = args.get("entity_name_key", "entity_name")
input_descriptions_key = args.get("input_descriptions_key", "description_list")
max_tokens = args.get("max_tokens", None)
extractor = SummarizeExtractor(
llm_invoker=llm,
summarization_prompt=summarize_prompt,
entity_name_key=entity_name_key,
input_descriptions_key=input_descriptions_key,
on_error=lambda e, stack, details: (
callbacks.error("Entity Extraction Error", e, stack, details)
if callbacks
else None
),
max_summary_length=args.get("max_summary_length", None),
max_input_tokens=max_tokens,
)
result = await extractor(id=id, descriptions=descriptions)
return SummarizedDescriptionResult(id=result.id, description=result.description)

@@ -0,0 +1,67 @@
# Copyright (c) 2024 Microsoft Corporation.
# Licensed under the MIT License
"""A module for producing a stable largest connected component, i.e. same input graph == same output lcc."""
import html
from typing import Any, cast
import networkx as nx
def stable_largest_connected_component(graph: nx.Graph) -> nx.Graph:
"""Return the largest connected component of the graph, with nodes and edges sorted in a stable way."""
# NOTE: The import is done here to reduce the initial import time of the module
from graspologic.utils import largest_connected_component
graph = graph.copy()
graph = cast("nx.Graph", largest_connected_component(graph))
graph = normalize_node_names(graph)
return _stabilize_graph(graph)
def _stabilize_graph(graph: nx.Graph) -> nx.Graph:
"""Ensure an undirected graph with the same relationships will always be read the same way."""
fixed_graph = nx.DiGraph() if graph.is_directed() else nx.Graph()
sorted_nodes = graph.nodes(data=True)
sorted_nodes = sorted(sorted_nodes, key=lambda x: x[0])
fixed_graph.add_nodes_from(sorted_nodes)
edges = list(graph.edges(data=True))
# If the graph is undirected, we create the edges in a stable way, so we get the same results
# for example:
# A -> B
# in graph theory is the same as
# B -> A
# in an undirected graph
# however, this can lead to downstream issues because sometimes
# consumers read graph.nodes() which ends up being [A, B] and sometimes it's [B, A]
# but they base some of their logic on the order of the nodes, so the order ends up being important
# so we sort the nodes in the edge in a stable way, so that we always get the same order
if not graph.is_directed():
def _sort_source_target(edge):
source, target, edge_data = edge
if source > target:
source, target = target, source
return source, target, edge_data
edges = [_sort_source_target(edge) for edge in edges]
def _get_edge_key(source: Any, target: Any) -> str:
return f"{source} -> {target}"
edges = sorted(edges, key=lambda x: _get_edge_key(x[0], x[1]))
fixed_graph.add_edges_from(edges)
return fixed_graph
def normalize_node_names(graph: nx.Graph | nx.DiGraph) -> nx.Graph | nx.DiGraph:
"""Normalize node names."""
node_mapping = {node: html.unescape(node.upper().strip()) for node in graph.nodes()} # type: ignore
return nx.relabel_nodes(graph, node_mapping)
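
The comment block above explains why edge ordering matters. Here is a self-contained sketch of the same stabilization idea (order each edge's endpoints, then sort the edge list), separate from the graspologic-backed function in this file.

import networkx as nx

def stable_edge_list(graph: nx.Graph):
    # Order each undirected edge's endpoints, then sort by a "source -> target" key,
    # so iteration order no longer depends on insertion order.
    edges = [(*sorted((u, v)), data) for u, v, data in graph.edges(data=True)]
    return sorted(edges, key=lambda e: f"{e[0]} -> {e[1]}")

g1, g2 = nx.Graph(), nx.Graph()
g1.add_edge("B", "A", weight=1.0)
g1.add_edge("C", "A", weight=2.0)
g2.add_edge("A", "C", weight=2.0)
g2.add_edge("A", "B", weight=1.0)

# The same graph built in a different order yields the same stabilized edge list.
assert stable_edge_list(g1) == stable_edge_list(g2)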

graphrag/py.typed Normal file
@@ -0,0 +1,2 @@
# This package supports type hinting,
# see https://www.python.org/dev/peps/pep-0561/#packaging-type-information

@@ -0,0 +1,109 @@
# Copyright (c) 2024 Microsoft Corporation.
# Licensed under the MIT License
"""CosmosDB Storage Tests."""
import json
import re
import sys
import pytest
from graphrag.storage.cosmosdb_pipeline_storage import CosmosDBPipelineStorage
# cspell:disable-next-line well-known-key
WELL_KNOWN_COSMOS_CONNECTION_STRING = "AccountEndpoint=https://127.0.0.1:8081/;AccountKey=C2y6yDjf5/R+ob0N8A7Cgv30VRDJIWEHLM+4QDU5DE2nQ9nDuVTqobD4b8mGGyPMbIZnqyMsEcaGQy67XIw/Jw=="
# the cosmosdb emulator is only available on windows runners at this time
if not sys.platform.startswith("win"):
pytest.skip(
"encountered windows-only tests -- will skip for now", allow_module_level=True
)
async def test_find():
storage = CosmosDBPipelineStorage(
connection_string=WELL_KNOWN_COSMOS_CONNECTION_STRING,
database_name="testfind",
container_name="testfindcontainer",
)
try:
try:
items = list(storage.find(file_pattern=re.compile(r".*\.json$")))
items = [item[0] for item in items]
assert items == []
json_content = {
"content": "Merry Christmas!",
}
await storage.set(
"christmas.json", json.dumps(json_content), encoding="utf-8"
)
items = list(storage.find(file_pattern=re.compile(r".*\.json$")))
items = [item[0] for item in items]
assert items == ["christmas.json"]
json_content = {
"content": "Hello, World!",
}
await storage.set("test.json", json.dumps(json_content), encoding="utf-8")
items = list(storage.find(file_pattern=re.compile(r".*\.json$")))
items = [item[0] for item in items]
assert items == ["christmas.json", "test.json"]
output = await storage.get("test.json")
output_json = json.loads(output)
assert output_json["content"] == "Hello, World!"
json_exists = await storage.has("christmas.json")
assert json_exists is True
json_exists = await storage.has("easter.json")
assert json_exists is False
finally:
await storage.delete("test.json")
output = await storage.get("test.json")
assert output is None
finally:
await storage.clear()
async def test_child():
storage = CosmosDBPipelineStorage(
connection_string=WELL_KNOWN_COSMOS_CONNECTION_STRING,
database_name="testchild",
container_name="testchildcontainer",
)
try:
child_storage = storage.child("child")
assert type(child_storage) is CosmosDBPipelineStorage
finally:
await storage.clear()
async def test_clear():
storage = CosmosDBPipelineStorage(
connection_string=WELL_KNOWN_COSMOS_CONNECTION_STRING,
database_name="testclear",
container_name="testclearcontainer",
)
try:
json_content = {
"content": "Merry Christmas!",
}
await storage.set("christmas.json", json.dumps(json_content), encoding="utf-8")
json_content = {
"content": "Happy Easter!",
}
await storage.set("easter.json", json.dumps(json_content), encoding="utf-8")
await storage.clear()
items = list(storage.find(file_pattern=re.compile(r".*\.json$")))
items = [item[0] for item in items]
assert items == []
output = await storage.get("easter.json")
assert output is None
assert storage._container_client is None # noqa: SLF001
assert storage._database_client is None # noqa: SLF001
finally:
await storage.clear()