diff --git a/.semversioner/1.0.1.json b/.semversioner/1.0.1.json new file mode 100644 index 00000000..e9c38b7b --- /dev/null +++ b/.semversioner/1.0.1.json @@ -0,0 +1,22 @@ +{ + "changes": [ + { + "description": "Fix encoding model config parsing", + "type": "patch" + }, + { + "description": "Fix exception on error callbacks", + "type": "patch" + }, + { + "description": "Manage llm instances inside a cached singleton. Check for empty dfs after entity/relationship extraction", + "type": "patch" + }, + { + "description": "Respect encoding_model option", + "type": "patch" + } + ], + "created_at": "2024-12-18T23:12:52+00:00", + "version": "1.0.1" +} \ No newline at end of file diff --git a/.semversioner/next-release/patch-20241121202210026640.json b/.semversioner/next-release/patch-20241121202210026640.json new file mode 100644 index 00000000..ac65aaf8 --- /dev/null +++ b/.semversioner/next-release/patch-20241121202210026640.json @@ -0,0 +1,4 @@ +{ + "type": "patch", + "description": "Implement cosmosdb storage option for cache and output" +} diff --git a/.semversioner/next-release/patch-20241212190223784600.json b/.semversioner/next-release/patch-20241212190223784600.json new file mode 100644 index 00000000..d54d621c --- /dev/null +++ b/.semversioner/next-release/patch-20241212190223784600.json @@ -0,0 +1,4 @@ +{ + "type": "patch", + "description": "Streamline flows." +} diff --git a/.semversioner/next-release/patch-20241213181544864279.json b/.semversioner/next-release/patch-20241213181544864279.json new file mode 100644 index 00000000..17361b58 --- /dev/null +++ b/.semversioner/next-release/patch-20241213181544864279.json @@ -0,0 +1,4 @@ +{ + "type": "patch", + "description": "Move extractor code to co-locate with operations." +} diff --git a/graphrag/index/llm/manager.py b/graphrag/index/llm/manager.py new file mode 100644 index 00000000..2e35cd89 --- /dev/null +++ b/graphrag/index/llm/manager.py @@ -0,0 +1,40 @@ +# Copyright (c) 2024 Microsoft Corporation. +# Licensed under the MIT License + +"""LLM Manager singleton.""" + +from functools import cache + +from fnllm import ChatLLM, EmbeddingsLLM + + +@cache +class ChatLLMSingleton: + """A singleton class for the chat LLM instances.""" + + def __init__(self): + self.llm_dict = {} + + def set_llm(self, name, llm): + """Add an LLM to the dictionary.""" + self.llm_dict[name] = llm + + def get_llm(self, name) -> ChatLLM | None: + """Get an LLM from the dictionary.""" + return self.llm_dict.get(name) + + +@cache +class EmbeddingsLLMSingleton: + """A singleton class for the embeddings LLM instances.""" + + def __init__(self): + self.llm_dict = {} + + def set_llm(self, name, llm): + """Add an LLM to the dictionary.""" + self.llm_dict[name] = llm + + def get_llm(self, name) -> EmbeddingsLLM | None: + """Get an LLM from the dictionary.""" + return self.llm_dict.get(name) diff --git a/graphrag/index/operations/compute_degree.py b/graphrag/index/operations/compute_degree.py new file mode 100644 index 00000000..b720bf6d --- /dev/null +++ b/graphrag/index/operations/compute_degree.py @@ -0,0 +1,15 @@ +# Copyright (c) 2024 Microsoft Corporation. 
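Because @cache memoizes the zero-argument constructor call, ChatLLMSingleton() and EmbeddingsLLMSingleton() always hand back the same per-process instance, which is what lets separate workflow steps share cached LLM clients. A minimal sketch of the intended behavior (illustrative only; the llm value below is a stand-in for a configured fnllm ChatLLM):

```python
from graphrag.index.llm.manager import ChatLLMSingleton

llm = object()  # stand-in for a configured fnllm ChatLLM

manager_a = ChatLLMSingleton()
manager_b = ChatLLMSingleton()
assert manager_a is manager_b  # functools.cache returns the same cached instance

manager_a.set_llm("entity_extraction", llm)
assert manager_b.get_llm("entity_extraction") is llm  # registry is shared
assert manager_b.get_llm("unknown") is None  # unknown names resolve to None
```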
+# Licensed under the MIT License + +"""A module containing create_graph definition.""" + +import networkx as nx +import pandas as pd + + +def compute_degree(graph: nx.Graph) -> pd.DataFrame: + """Create a new DataFrame with the degree of each node in the graph.""" + return pd.DataFrame([ + {"title": node, "degree": int(degree)} + for node, degree in graph.degree # type: ignore + ]) diff --git a/graphrag/index/operations/embed_graph/embed_node2vec.py b/graphrag/index/operations/embed_graph/embed_node2vec.py new file mode 100644 index 00000000..a009c670 --- /dev/null +++ b/graphrag/index/operations/embed_graph/embed_node2vec.py @@ -0,0 +1,43 @@ +# Copyright (c) 2024 Microsoft Corporation. +# Licensed under the MIT License + +"""Utilities to generate graph embeddings.""" + +from dataclasses import dataclass + +import networkx as nx +import numpy as np + + +@dataclass +class NodeEmbeddings: + """Node embeddings class definition.""" + + nodes: list[str] + embeddings: np.ndarray + + +def embed_node2vec( + graph: nx.Graph | nx.DiGraph, + dimensions: int = 1536, + num_walks: int = 10, + walk_length: int = 40, + window_size: int = 2, + iterations: int = 3, + random_seed: int = 86, +) -> NodeEmbeddings: + """Generate node embeddings using Node2Vec.""" + # NOTE: This import is done here to reduce the initial import time of the graphrag package + import graspologic as gc + + # generate embedding + lcc_tensors = gc.embed.node2vec_embed( # type: ignore + graph=graph, + dimensions=dimensions, + window_size=window_size, + iterations=iterations, + num_walks=num_walks, + walk_length=walk_length, + random_seed=random_seed, + ) + return NodeEmbeddings(embeddings=lcc_tensors[0], nodes=lcc_tensors[1]) diff --git a/graphrag/index/operations/extract_covariates/claim_extractor.py b/graphrag/index/operations/extract_covariates/claim_extractor.py new file mode 100644 index 00000000..66162f8f --- /dev/null +++ b/graphrag/index/operations/extract_covariates/claim_extractor.py @@ -0,0 +1,242 @@ +# Copyright (c) 2024 Microsoft Corporation. 
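For reference, compute_degree produces one row per node with the column names used downstream ("title", "degree"). A small sketch on a toy graph:

```python
import networkx as nx

from graphrag.index.operations.compute_degree import compute_degree

graph = nx.Graph()
graph.add_edges_from([("A", "B"), ("A", "C")])

degree_df = compute_degree(graph)
#   title  degree
# 0     A       2
# 1     B       1
# 2     C       1
```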
+# Licensed under the MIT License + +"""A module containing 'ClaimExtractorResult' and 'ClaimExtractor' models.""" + +import logging +import traceback +from dataclasses import dataclass +from typing import Any + +import tiktoken +from fnllm import ChatLLM + +import graphrag.config.defaults as defs +from graphrag.index.typing import ErrorHandlerFn +from graphrag.prompts.index.claim_extraction import ( + CLAIM_EXTRACTION_PROMPT, + CONTINUE_PROMPT, + LOOP_PROMPT, +) + +DEFAULT_TUPLE_DELIMITER = "<|>" +DEFAULT_RECORD_DELIMITER = "##" +DEFAULT_COMPLETION_DELIMITER = "<|COMPLETE|>" +log = logging.getLogger(__name__) + + +@dataclass +class ClaimExtractorResult: + """Claim extractor result class definition.""" + + output: list[dict] + source_docs: dict[str, Any] + + +class ClaimExtractor: + """Claim extractor class definition.""" + + _llm: ChatLLM + _extraction_prompt: str + _summary_prompt: str + _output_formatter_prompt: str + _input_text_key: str + _input_entity_spec_key: str + _input_claim_description_key: str + _tuple_delimiter_key: str + _record_delimiter_key: str + _completion_delimiter_key: str + _max_gleanings: int + _on_error: ErrorHandlerFn + _loop_args: dict[str, Any] + + def __init__( + self, + llm_invoker: ChatLLM, + extraction_prompt: str | None = None, + input_text_key: str | None = None, + input_entity_spec_key: str | None = None, + input_claim_description_key: str | None = None, + input_resolved_entities_key: str | None = None, + tuple_delimiter_key: str | None = None, + record_delimiter_key: str | None = None, + completion_delimiter_key: str | None = None, + encoding_model: str | None = None, + max_gleanings: int | None = None, + on_error: ErrorHandlerFn | None = None, + ): + """Init method definition.""" + self._llm = llm_invoker + self._extraction_prompt = extraction_prompt or CLAIM_EXTRACTION_PROMPT + self._input_text_key = input_text_key or "input_text" + self._input_entity_spec_key = input_entity_spec_key or "entity_specs" + self._tuple_delimiter_key = tuple_delimiter_key or "tuple_delimiter" + self._record_delimiter_key = record_delimiter_key or "record_delimiter" + self._completion_delimiter_key = ( + completion_delimiter_key or "completion_delimiter" + ) + self._input_claim_description_key = ( + input_claim_description_key or "claim_description" + ) + self._input_resolved_entities_key = ( + input_resolved_entities_key or "resolved_entities" + ) + self._max_gleanings = ( + max_gleanings if max_gleanings is not None else defs.CLAIM_MAX_GLEANINGS + ) + self._on_error = on_error or (lambda _e, _s, _d: None) + + # Construct the looping arguments + encoding = tiktoken.get_encoding(encoding_model or "cl100k_base") + yes = f"{encoding.encode('YES')[0]}" + no = f"{encoding.encode('NO')[0]}" + self._loop_args = {"logit_bias": {yes: 100, no: 100}, "max_tokens": 1} + + async def __call__( + self, inputs: dict[str, Any], prompt_variables: dict | None = None + ) -> ClaimExtractorResult: + """Call method definition.""" + if prompt_variables is None: + prompt_variables = {} + texts = inputs[self._input_text_key] + entity_spec = str(inputs[self._input_entity_spec_key]) + claim_description = inputs[self._input_claim_description_key] + resolved_entities = inputs.get(self._input_resolved_entities_key, {}) + source_doc_map = {} + + prompt_args = { + self._input_entity_spec_key: entity_spec, + self._input_claim_description_key: claim_description, + self._tuple_delimiter_key: prompt_variables.get(self._tuple_delimiter_key) + or DEFAULT_TUPLE_DELIMITER, + self._record_delimiter_key: 
prompt_variables.get(self._record_delimiter_key) + or DEFAULT_RECORD_DELIMITER, + self._completion_delimiter_key: prompt_variables.get( + self._completion_delimiter_key + ) + or DEFAULT_COMPLETION_DELIMITER, + } + + all_claims: list[dict] = [] + for doc_index, text in enumerate(texts): + document_id = f"d{doc_index}" + try: + claims = await self._process_document(prompt_args, text, doc_index) + all_claims += [ + self._clean_claim(c, document_id, resolved_entities) for c in claims + ] + source_doc_map[document_id] = text + except Exception as e: + log.exception("error extracting claim") + self._on_error( + e, + traceback.format_exc(), + {"doc_index": doc_index, "text": text}, + ) + continue + + return ClaimExtractorResult( + output=all_claims, + source_docs=source_doc_map, + ) + + def _clean_claim( + self, claim: dict, document_id: str, resolved_entities: dict + ) -> dict: + # clean the parsed claims to remove any claims with status = False + obj = claim.get("object_id", claim.get("object")) + subject = claim.get("subject_id", claim.get("subject")) + + # If subject or object in resolved entities, then replace with resolved entity + obj = resolved_entities.get(obj, obj) + subject = resolved_entities.get(subject, subject) + claim["object_id"] = obj + claim["subject_id"] = subject + return claim + + async def _process_document( + self, prompt_args: dict, doc, doc_index: int + ) -> list[dict]: + record_delimiter = prompt_args.get( + self._record_delimiter_key, DEFAULT_RECORD_DELIMITER + ) + completion_delimiter = prompt_args.get( + self._completion_delimiter_key, DEFAULT_COMPLETION_DELIMITER + ) + + response = await self._llm( + self._extraction_prompt.format(**{ + self._input_text_key: doc, + **prompt_args, + }) + ) + results = response.output.content or "" + claims = results.strip().removesuffix(completion_delimiter) + + # Repeat to ensure we maximize entity count + for i in range(self._max_gleanings): + response = await self._llm( + CONTINUE_PROMPT, + name=f"extract-continuation-{i}", + history=response.history, + ) + extension = response.output.content or "" + claims += record_delimiter + extension.strip().removesuffix( + completion_delimiter + ) + + # If this isn't the last loop, check to see if we should continue + if i >= self._max_gleanings - 1: + break + + response = await self._llm( + LOOP_PROMPT, + name=f"extract-loopcheck-{i}", + history=response.history, + model_parameters=self._loop_args, + ) + if response.output.content != "YES": + break + + return self._parse_claim_tuples(results, prompt_args) + + def _parse_claim_tuples( + self, claims: str, prompt_variables: dict + ) -> list[dict[str, Any]]: + """Parse claim tuples.""" + record_delimiter = prompt_variables.get( + self._record_delimiter_key, DEFAULT_RECORD_DELIMITER + ) + completion_delimiter = prompt_variables.get( + self._completion_delimiter_key, DEFAULT_COMPLETION_DELIMITER + ) + tuple_delimiter = prompt_variables.get( + self._tuple_delimiter_key, DEFAULT_TUPLE_DELIMITER + ) + + def pull_field(index: int, fields: list[str]) -> str | None: + return fields[index].strip() if len(fields) > index else None + + result: list[dict[str, Any]] = [] + claims_values = ( + claims.strip().removesuffix(completion_delimiter).split(record_delimiter) + ) + for claim in claims_values: + claim = claim.strip().removeprefix("(").removesuffix(")") + + # Ignore the completion delimiter + if claim == completion_delimiter: + continue + + claim_fields = claim.split(tuple_delimiter) + result.append({ + "subject_id": pull_field(0, claim_fields), + 
"object_id": pull_field(1, claim_fields), + "type": pull_field(2, claim_fields), + "status": pull_field(3, claim_fields), + "start_date": pull_field(4, claim_fields), + "end_date": pull_field(5, claim_fields), + "description": pull_field(6, claim_fields), + "source_text": pull_field(7, claim_fields), + }) + return result diff --git a/graphrag/index/operations/extract_entities/graph_extractor.py b/graphrag/index/operations/extract_entities/graph_extractor.py new file mode 100644 index 00000000..26df56a9 --- /dev/null +++ b/graphrag/index/operations/extract_entities/graph_extractor.py @@ -0,0 +1,308 @@ +# Copyright (c) 2024 Microsoft Corporation. +# Licensed under the MIT License + +"""A module containing 'GraphExtractionResult' and 'GraphExtractor' models.""" + +import logging +import re +import traceback +from collections.abc import Mapping +from dataclasses import dataclass +from typing import Any + +import networkx as nx +import tiktoken +from fnllm import ChatLLM + +import graphrag.config.defaults as defs +from graphrag.index.typing import ErrorHandlerFn +from graphrag.index.utils.string import clean_str +from graphrag.prompts.index.entity_extraction import ( + CONTINUE_PROMPT, + GRAPH_EXTRACTION_PROMPT, + LOOP_PROMPT, +) + +DEFAULT_TUPLE_DELIMITER = "<|>" +DEFAULT_RECORD_DELIMITER = "##" +DEFAULT_COMPLETION_DELIMITER = "<|COMPLETE|>" +DEFAULT_ENTITY_TYPES = ["organization", "person", "geo", "event"] + +log = logging.getLogger(__name__) + + +@dataclass +class GraphExtractionResult: + """Unipartite graph extraction result class definition.""" + + output: nx.Graph + source_docs: dict[Any, Any] + + +class GraphExtractor: + """Unipartite graph extractor class definition.""" + + _llm: ChatLLM + _join_descriptions: bool + _tuple_delimiter_key: str + _record_delimiter_key: str + _entity_types_key: str + _input_text_key: str + _completion_delimiter_key: str + _entity_name_key: str + _input_descriptions_key: str + _extraction_prompt: str + _summarization_prompt: str + _loop_args: dict[str, Any] + _max_gleanings: int + _on_error: ErrorHandlerFn + + def __init__( + self, + llm_invoker: ChatLLM, + tuple_delimiter_key: str | None = None, + record_delimiter_key: str | None = None, + input_text_key: str | None = None, + entity_types_key: str | None = None, + completion_delimiter_key: str | None = None, + prompt: str | None = None, + join_descriptions=True, + encoding_model: str | None = None, + max_gleanings: int | None = None, + on_error: ErrorHandlerFn | None = None, + ): + """Init method definition.""" + # TODO: streamline construction + self._llm = llm_invoker + self._join_descriptions = join_descriptions + self._input_text_key = input_text_key or "input_text" + self._tuple_delimiter_key = tuple_delimiter_key or "tuple_delimiter" + self._record_delimiter_key = record_delimiter_key or "record_delimiter" + self._completion_delimiter_key = ( + completion_delimiter_key or "completion_delimiter" + ) + self._entity_types_key = entity_types_key or "entity_types" + self._extraction_prompt = prompt or GRAPH_EXTRACTION_PROMPT + self._max_gleanings = ( + max_gleanings + if max_gleanings is not None + else defs.ENTITY_EXTRACTION_MAX_GLEANINGS + ) + self._on_error = on_error or (lambda _e, _s, _d: None) + + # Construct the looping arguments + encoding = tiktoken.get_encoding(encoding_model or "cl100k_base") + yes = f"{encoding.encode('YES')[0]}" + no = f"{encoding.encode('NO')[0]}" + self._loop_args = {"logit_bias": {yes: 100, no: 100}, "max_tokens": 1} + + async def __call__( + self, texts: list[str], 
prompt_variables: dict[str, Any] | None = None + ) -> GraphExtractionResult: + """Call method definition.""" + if prompt_variables is None: + prompt_variables = {} + all_records: dict[int, str] = {} + source_doc_map: dict[int, str] = {} + + # Wire defaults into the prompt variables + prompt_variables = { + **prompt_variables, + self._tuple_delimiter_key: prompt_variables.get(self._tuple_delimiter_key) + or DEFAULT_TUPLE_DELIMITER, + self._record_delimiter_key: prompt_variables.get(self._record_delimiter_key) + or DEFAULT_RECORD_DELIMITER, + self._completion_delimiter_key: prompt_variables.get( + self._completion_delimiter_key + ) + or DEFAULT_COMPLETION_DELIMITER, + self._entity_types_key: ",".join( + prompt_variables[self._entity_types_key] or DEFAULT_ENTITY_TYPES + ), + } + + for doc_index, text in enumerate(texts): + try: + # Invoke the entity extraction + result = await self._process_document(text, prompt_variables) + source_doc_map[doc_index] = text + all_records[doc_index] = result + except Exception as e: + log.exception("error extracting graph") + self._on_error( + e, + traceback.format_exc(), + { + "doc_index": doc_index, + "text": text, + }, + ) + + output = await self._process_results( + all_records, + prompt_variables.get(self._tuple_delimiter_key, DEFAULT_TUPLE_DELIMITER), + prompt_variables.get(self._record_delimiter_key, DEFAULT_RECORD_DELIMITER), + ) + + return GraphExtractionResult( + output=output, + source_docs=source_doc_map, + ) + + async def _process_document( + self, text: str, prompt_variables: dict[str, str] + ) -> str: + response = await self._llm( + self._extraction_prompt.format(**{ + **prompt_variables, + self._input_text_key: text, + }), + ) + results = response.output.content or "" + + # Repeat to ensure we maximize entity count + for i in range(self._max_gleanings): + response = await self._llm( + CONTINUE_PROMPT, + name=f"extract-continuation-{i}", + history=response.history, + ) + results += response.output.content or "" + + # if this is the final glean, don't bother updating the continuation flag + if i >= self._max_gleanings - 1: + break + + response = await self._llm( + LOOP_PROMPT, + name=f"extract-loopcheck-{i}", + history=response.history, + model_parameters=self._loop_args, + ) + if response.output != "YES": + break + + return results + + async def _process_results( + self, + results: dict[int, str], + tuple_delimiter: str, + record_delimiter: str, + ) -> nx.Graph: + """Parse the result string to create an undirected unipartite graph. 
+ + Args: + - results - dict of results from the extraction chain + - tuple_delimiter - delimiter between tuples in an output record, default is '<|>' + - record_delimiter - delimiter between records, default is '##' + Returns: + - output - unipartite graph in graphML format + """ + graph = nx.Graph() + for source_doc_id, extracted_data in results.items(): + records = [r.strip() for r in extracted_data.split(record_delimiter)] + + for record in records: + record = re.sub(r"^\(|\)$", "", record.strip()) + record_attributes = record.split(tuple_delimiter) + + if record_attributes[0] == '"entity"' and len(record_attributes) >= 4: + # add this record as a node in the G + entity_name = clean_str(record_attributes[1].upper()) + entity_type = clean_str(record_attributes[2].upper()) + entity_description = clean_str(record_attributes[3]) + + if entity_name in graph.nodes(): + node = graph.nodes[entity_name] + if self._join_descriptions: + node["description"] = "\n".join( + list({ + *_unpack_descriptions(node), + entity_description, + }) + ) + else: + if len(entity_description) > len(node["description"]): + node["description"] = entity_description + node["source_id"] = ", ".join( + list({ + *_unpack_source_ids(node), + str(source_doc_id), + }) + ) + node["type"] = ( + entity_type if entity_type != "" else node["type"] + ) + else: + graph.add_node( + entity_name, + type=entity_type, + description=entity_description, + source_id=str(source_doc_id), + ) + + if ( + record_attributes[0] == '"relationship"' + and len(record_attributes) >= 5 + ): + # add this record as edge + source = clean_str(record_attributes[1].upper()) + target = clean_str(record_attributes[2].upper()) + edge_description = clean_str(record_attributes[3]) + edge_source_id = clean_str(str(source_doc_id)) + try: + weight = float(record_attributes[-1]) + except ValueError: + weight = 1.0 + + if source not in graph.nodes(): + graph.add_node( + source, + type="", + description="", + source_id=edge_source_id, + ) + if target not in graph.nodes(): + graph.add_node( + target, + type="", + description="", + source_id=edge_source_id, + ) + if graph.has_edge(source, target): + edge_data = graph.get_edge_data(source, target) + if edge_data is not None: + weight += edge_data["weight"] + if self._join_descriptions: + edge_description = "\n".join( + list({ + *_unpack_descriptions(edge_data), + edge_description, + }) + ) + edge_source_id = ", ".join( + list({ + *_unpack_source_ids(edge_data), + str(source_doc_id), + }) + ) + graph.add_edge( + source, + target, + weight=weight, + description=edge_description, + source_id=edge_source_id, + ) + + return graph + + +def _unpack_descriptions(data: Mapping) -> list[str]: + value = data.get("description", None) + return [] if value is None else value.split("\n") + + +def _unpack_source_ids(data: Mapping) -> list[str]: + value = data.get("source_id", None) + return [] if value is None else value.split(", ") diff --git a/graphrag/index/operations/extract_entities/graph_intelligence_strategy.py b/graphrag/index/operations/extract_entities/graph_intelligence_strategy.py new file mode 100644 index 00000000..a91e0748 --- /dev/null +++ b/graphrag/index/operations/extract_entities/graph_intelligence_strategy.py @@ -0,0 +1,139 @@ +# Copyright (c) 2024 Microsoft Corporation. 
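The parsing contract of _process_results above is easiest to see on a hand-written response. The sketch below uses the default delimiters defined in graph_extractor.py (tuple "<|>", record "##", completion "<|COMPLETE|>"); the entities and relationship are invented:

```python
raw = (
    '("entity"<|>ACME CORP<|>ORGANIZATION<|>A fictional manufacturer)##'
    '("entity"<|>JANE DOE<|>PERSON<|>Hypothetical CEO of Acme Corp)##'
    '("relationship"<|>JANE DOE<|>ACME CORP<|>Jane Doe leads Acme Corp<|>2)'
    "<|COMPLETE|>"
)

records = [
    r.strip().removeprefix("(").removesuffix(")")
    for r in raw.removesuffix("<|COMPLETE|>").split("##")
]
for record in records:
    fields = record.split("<|>")
    if fields[0] == '"entity"':
        print("node:", fields[1], fields[2], fields[3])
    elif fields[0] == '"relationship"':
        print("edge:", fields[1], "->", fields[2], "weight:", fields[-1])
```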
+# Licensed under the MIT License + +"""A module containing run_graph_intelligence, run_extract_entities and _create_text_splitter methods to run graph intelligence.""" + +import networkx as nx +from datashaper import VerbCallbacks +from fnllm import ChatLLM + +import graphrag.config.defaults as defs +from graphrag.cache.pipeline_cache import PipelineCache +from graphrag.index.llm.load_llm import load_llm, read_llm_params +from graphrag.index.operations.extract_entities.graph_extractor import GraphExtractor +from graphrag.index.operations.extract_entities.typing import ( + Document, + EntityExtractionResult, + EntityTypes, + StrategyConfig, +) +from graphrag.index.text_splitting.text_splitting import ( + NoopTextSplitter, + TextSplitter, + TokenTextSplitter, +) + + +async def run_graph_intelligence( + docs: list[Document], + entity_types: EntityTypes, + callbacks: VerbCallbacks, + cache: PipelineCache, + args: StrategyConfig, +) -> EntityExtractionResult: + """Run the graph intelligence entity extraction strategy.""" + llm_config = read_llm_params(args.get("llm", {})) + llm = load_llm("entity_extraction", llm_config, callbacks=callbacks, cache=cache) + return await run_extract_entities(llm, docs, entity_types, callbacks, args) + + +async def run_extract_entities( + llm: ChatLLM, + docs: list[Document], + entity_types: EntityTypes, + callbacks: VerbCallbacks | None, + args: StrategyConfig, +) -> EntityExtractionResult: + """Run the entity extraction chain.""" + encoding_name = args.get("encoding_name", "cl100k_base") + + # Chunking Arguments + prechunked = args.get("prechunked", False) + chunk_size = args.get("chunk_size", defs.CHUNK_SIZE) + chunk_overlap = args.get("chunk_overlap", defs.CHUNK_OVERLAP) + + # Extraction Arguments + tuple_delimiter = args.get("tuple_delimiter", None) + record_delimiter = args.get("record_delimiter", None) + completion_delimiter = args.get("completion_delimiter", None) + extraction_prompt = args.get("extraction_prompt", None) + encoding_model = args.get("encoding_name", None) + max_gleanings = args.get("max_gleanings", defs.ENTITY_EXTRACTION_MAX_GLEANINGS) + + # note: We're not using UnipartiteGraphChain.from_params + # because we want to pass "timeout" to the llm_kwargs + text_splitter = _create_text_splitter( + prechunked, chunk_size, chunk_overlap, encoding_name + ) + + extractor = GraphExtractor( + llm_invoker=llm, + prompt=extraction_prompt, + encoding_model=encoding_model, + max_gleanings=max_gleanings, + on_error=lambda e, s, d: ( + callbacks.error("Entity Extraction Error", e, s, d) if callbacks else None + ), + ) + text_list = [doc.text.strip() for doc in docs] + + # If it's not pre-chunked, then re-chunk the input + if not prechunked: + text_list = text_splitter.split_text("\n".join(text_list)) + + results = await extractor( + list(text_list), + { + "entity_types": entity_types, + "tuple_delimiter": tuple_delimiter, + "record_delimiter": record_delimiter, + "completion_delimiter": completion_delimiter, + }, + ) + + graph = results.output + # Map the "source_id" back to the "id" field + for _, node in graph.nodes(data=True): # type: ignore + if node is not None: + node["source_id"] = ",".join( + docs[int(id)].id for id in node["source_id"].split(",") + ) + + for _, _, edge in graph.edges(data=True): # type: ignore + if edge is not None: + edge["source_id"] = ",".join( + docs[int(id)].id for id in edge["source_id"].split(",") + ) + + entities = [ + ({"title": item[0], **(item[1] or {})}) + for item in graph.nodes(data=True) + if item is not None + ] + 
+ relationships = nx.to_pandas_edgelist(graph) + + return EntityExtractionResult(entities, relationships, graph) + + +def _create_text_splitter( + prechunked: bool, chunk_size: int, chunk_overlap: int, encoding_name: str +) -> TextSplitter: + """Create a text splitter for the extraction chain. + + Args: + - prechunked - Whether the text is already chunked + - chunk_size - The size of each chunk + - chunk_overlap - The overlap between chunks + - encoding_name - The name of the encoding to use + Returns: + - output - A text splitter + """ + if prechunked: + return NoopTextSplitter() + + return TokenTextSplitter( + chunk_size=chunk_size, + chunk_overlap=chunk_overlap, + encoding_name=encoding_name, + ) diff --git a/graphrag/index/operations/extract_entities/nltk_strategy.py b/graphrag/index/operations/extract_entities/nltk_strategy.py new file mode 100644 index 00000000..81103c69 --- /dev/null +++ b/graphrag/index/operations/extract_entities/nltk_strategy.py @@ -0,0 +1,66 @@ +# Copyright (c) 2024 Microsoft Corporation. +# Licensed under the MIT License + +"""A module containing run method definition.""" + +import networkx as nx +import nltk +from datashaper import VerbCallbacks +from nltk.corpus import words + +from graphrag.cache.pipeline_cache import PipelineCache +from graphrag.index.operations.extract_entities.typing import ( + Document, + EntityExtractionResult, + EntityTypes, + StrategyConfig, +) + +# Need to do this cause we're potentially multithreading, and nltk doesn't like that +words.ensure_loaded() + + +async def run( # noqa RUF029 async is required for interface + docs: list[Document], + entity_types: EntityTypes, + callbacks: VerbCallbacks, # noqa ARG001 + cache: PipelineCache, # noqa ARG001 + args: StrategyConfig, # noqa ARG001 +) -> EntityExtractionResult: + """Run method definition.""" + entity_map = {} + graph = nx.Graph() + for doc in docs: + connected_entities = [] + for chunk in nltk.ne_chunk(nltk.pos_tag(nltk.word_tokenize(doc.text))): + if hasattr(chunk, "label"): + entity_type = chunk.label().lower() + if entity_type in entity_types: + name = (" ".join(c[0] for c in chunk)).upper() + connected_entities.append(name) + if name not in entity_map: + entity_map[name] = entity_type + graph.add_node( + name, type=entity_type, description=name, source_id=doc.id + ) + + # connect the entities if they appear in the same document + if len(connected_entities) > 1: + for i in range(len(connected_entities)): + for j in range(i + 1, len(connected_entities)): + description = f"{connected_entities[i]} -> {connected_entities[j]}" + graph.add_edge( + connected_entities[i], + connected_entities[j], + description=description, + source_id=doc.id, + ) + + return EntityExtractionResult( + entities=[ + {"type": entity_type, "title": name} + for name, entity_type in entity_map.items() + ], + relationships=[], + graph=graph, + ) diff --git a/graphrag/index/operations/extract_entities/typing.py b/graphrag/index/operations/extract_entities/typing.py new file mode 100644 index 00000000..7eb24406 --- /dev/null +++ b/graphrag/index/operations/extract_entities/typing.py @@ -0,0 +1,59 @@ +# Copyright (c) 2024 Microsoft Corporation. 
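One detail worth calling out in run_extract_entities above: GraphExtractor records source_id as comma-joined document indices, and the strategy remaps those indices back to the caller's document ids before returning. A sketch of that remapping, using the Document dataclass from the typing module below:

```python
from graphrag.index.operations.extract_entities.typing import Document

docs = [Document(text="...", id="doc-a"), Document(text="...", id="doc-b")]

source_id = "0,1"  # as recorded by GraphExtractor (per-call document indices)
remapped = ",".join(docs[int(i)].id for i in source_id.split(","))
assert remapped == "doc-a,doc-b"
```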
+# Licensed under the MIT License + +"""A module containing 'Document' and 'EntityExtractionResult' models.""" + +from collections.abc import Awaitable, Callable +from dataclasses import dataclass +from enum import Enum +from typing import Any + +import networkx as nx +from datashaper import VerbCallbacks + +from graphrag.cache.pipeline_cache import PipelineCache + +ExtractedEntity = dict[str, Any] +ExtractedRelationship = dict[str, Any] +StrategyConfig = dict[str, Any] +EntityTypes = list[str] + + +@dataclass +class Document: + """Document class definition.""" + + text: str + id: str + + +@dataclass +class EntityExtractionResult: + """Entity extraction result class definition.""" + + entities: list[ExtractedEntity] + relationships: list[ExtractedRelationship] + graph: nx.Graph | None + + +EntityExtractStrategy = Callable[ + [ + list[Document], + EntityTypes, + VerbCallbacks, + PipelineCache, + StrategyConfig, + ], + Awaitable[EntityExtractionResult], +] + + +class ExtractEntityStrategyType(str, Enum): + """ExtractEntityStrategyType class definition.""" + + graph_intelligence = "graph_intelligence" + nltk = "nltk" + + def __repr__(self): + """Get a string representation.""" + return f'"{self.value}"' diff --git a/graphrag/index/operations/layout_graph/typing.py b/graphrag/index/operations/layout_graph/typing.py new file mode 100644 index 00000000..ae46afa9 --- /dev/null +++ b/graphrag/index/operations/layout_graph/typing.py @@ -0,0 +1,27 @@ +# Copyright (c) 2024 Microsoft Corporation. +# Licensed under the MIT License + +# Use this for now instead of a wrapper +"""A module containing 'NodePosition' model.""" + +from dataclasses import dataclass + + +@dataclass +class NodePosition: + """Node position class definition.""" + + label: str + cluster: str + size: float + + x: float + y: float + z: float | None = None + + def to_pandas(self) -> tuple[str, float, float, str, float]: + """To pandas method definition.""" + return self.label, self.x, self.y, self.cluster, self.size + + +GraphLayout = list[NodePosition] diff --git a/graphrag/index/operations/layout_graph/umap.py b/graphrag/index/operations/layout_graph/umap.py new file mode 100644 index 00000000..e5ab1668 --- /dev/null +++ b/graphrag/index/operations/layout_graph/umap.py @@ -0,0 +1,136 @@ +# Copyright (c) 2024 Microsoft Corporation. 
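EntityExtractStrategy is the extension point both the graph-intelligence and nltk strategies implement. A hypothetical no-op strategy that satisfies the callable signature (invented behavior, shown only to illustrate the shape of the contract):

```python
from graphrag.index.operations.extract_entities.typing import (
    Document,
    EntityExtractionResult,
    EntityTypes,
    StrategyConfig,
)


async def run_noop(
    docs: list[Document],
    entity_types: EntityTypes,
    callbacks,  # datashaper VerbCallbacks
    cache,  # PipelineCache
    args: StrategyConfig,
) -> EntityExtractionResult:
    # Tag each document with a single placeholder entity and no relationships.
    entities = [{"title": doc.id.upper(), "type": "document"} for doc in docs]
    return EntityExtractionResult(entities=entities, relationships=[], graph=None)
```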
+# Licensed under the MIT License + +"""A module containing run and _create_node_position methods definitions.""" + +import logging +import traceback +from typing import Any + +import networkx as nx +import numpy as np + +from graphrag.index.operations.embed_graph.typing import NodeEmbeddings +from graphrag.index.operations.layout_graph.typing import ( + GraphLayout, + NodePosition, +) +from graphrag.index.typing import ErrorHandlerFn + +# TODO: This could be handled more elegantly, like what columns to use +# for "size" or "cluster" +# We could also have a boolean to indicate to use node sizes or clusters + +log = logging.getLogger(__name__) + + +def run( + graph: nx.Graph, + embeddings: NodeEmbeddings, + args: dict[str, Any], + on_error: ErrorHandlerFn, +) -> GraphLayout: + """Run method definition.""" + node_clusters = [] + node_sizes = [] + + embeddings = _filter_raw_embeddings(embeddings) + nodes = list(embeddings.keys()) + embedding_vectors = [embeddings[node_id] for node_id in nodes] + + for node_id in nodes: + node = graph.nodes[node_id] + cluster = node.get("cluster", node.get("community", -1)) + node_clusters.append(cluster) + size = node.get("degree", node.get("size", 0)) + node_sizes.append(size) + + additional_args = {} + if len(node_clusters) > 0: + additional_args["node_categories"] = node_clusters + if len(node_sizes) > 0: + additional_args["node_sizes"] = node_sizes + + try: + return compute_umap_positions( + embedding_vectors=np.array(embedding_vectors), + node_labels=nodes, + **additional_args, + min_dist=args.get("min_dist", 0.75), + n_neighbors=args.get("n_neighbors", 5), + ) + except Exception as e: + log.exception("Error running UMAP") + on_error(e, traceback.format_exc(), None) + # Umap may fail due to input sparseness or memory pressure. 
+ # For now, in these cases, we'll just return a layout with all nodes at (0, 0) + result = [] + for i in range(len(nodes)): + cluster = node_clusters[i] if len(node_clusters) > 0 else 1 + result.append( + NodePosition(x=0, y=0, label=nodes[i], size=0, cluster=str(cluster)) + ) + return result + + +def _filter_raw_embeddings(embeddings: NodeEmbeddings) -> NodeEmbeddings: + return { + node_id: embedding + for node_id, embedding in embeddings.items() + if embedding is not None + } + + +def compute_umap_positions( + embedding_vectors: np.ndarray, + node_labels: list[str], + node_categories: list[int] | None = None, + node_sizes: list[int] | None = None, + min_dist: float = 0.75, + n_neighbors: int = 25, + spread: int = 1, + metric: str = "euclidean", + n_components: int = 2, + random_state: int = 86, +) -> list[NodePosition]: + """Project embedding vectors down to 2D/3D using UMAP.""" + # NOTE: This import is done here to reduce the initial import time of the graphrag package + import umap + + embedding_positions = umap.UMAP( + min_dist=min_dist, + n_neighbors=n_neighbors, + spread=spread, + n_components=n_components, + metric=metric, + random_state=random_state, + ).fit_transform(embedding_vectors) + + embedding_position_data: list[NodePosition] = [] + for index, node_name in enumerate(node_labels): + node_points = embedding_positions[index] # type: ignore + node_category = 1 if node_categories is None else node_categories[index] + node_size = 1 if node_sizes is None else node_sizes[index] + + if len(node_points) == 2: + embedding_position_data.append( + NodePosition( + label=str(node_name), + x=float(node_points[0]), + y=float(node_points[1]), + cluster=str(int(node_category)), + size=int(node_size), + ) + ) + else: + embedding_position_data.append( + NodePosition( + label=str(node_name), + x=float(node_points[0]), + y=float(node_points[1]), + z=float(node_points[2]), + cluster=str(int(node_category)), + size=int(node_size), + ) + ) + return embedding_position_data diff --git a/graphrag/index/operations/layout_graph/zero.py b/graphrag/index/operations/layout_graph/zero.py new file mode 100644 index 00000000..4bb7d39b --- /dev/null +++ b/graphrag/index/operations/layout_graph/zero.py @@ -0,0 +1,98 @@ +# Copyright (c) 2024 Microsoft Corporation. 
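compute_umap_positions can be exercised directly on synthetic vectors; a hedged sketch assuming the optional umap-learn dependency is installed (labels, categories, and sizes below are invented):

```python
import numpy as np

from graphrag.index.operations.layout_graph.umap import compute_umap_positions

vectors = np.random.default_rng(86).normal(size=(30, 16))
layout = compute_umap_positions(
    embedding_vectors=vectors,
    node_labels=[f"node-{i}" for i in range(30)],
    node_categories=[i % 3 for i in range(30)],
    node_sizes=[1 + (i % 5) for i in range(30)],
    n_neighbors=5,
)
# layout is a list[NodePosition] with 2D x/y coordinates, one per input label
```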
+# Licensed under the MIT License + +"""A module containing run and _create_node_position methods definitions.""" + +import logging +import traceback +from typing import Any + +import networkx as nx + +from graphrag.index.operations.layout_graph.typing import ( + GraphLayout, + NodePosition, +) +from graphrag.index.typing import ErrorHandlerFn + +# TODO: This could be handled more elegantly, like what columns to use +# for "size" or "cluster" +# We could also have a boolean to indicate to use node sizes or clusters + +log = logging.getLogger(__name__) + + +def run( + graph: nx.Graph, + _args: dict[str, Any], + on_error: ErrorHandlerFn, +) -> GraphLayout: + """Run method definition.""" + node_clusters = [] + node_sizes = [] + + nodes = list(graph.nodes) + + for node_id in nodes: + node = graph.nodes[node_id] + cluster = node.get("cluster", node.get("community", -1)) + node_clusters.append(cluster) + size = node.get("degree", node.get("size", 0)) + node_sizes.append(size) + + additional_args = {} + if len(node_clusters) > 0: + additional_args["node_categories"] = node_clusters + if len(node_sizes) > 0: + additional_args["node_sizes"] = node_sizes + + try: + return get_zero_positions(node_labels=nodes, **additional_args) + except Exception as e: + log.exception("Error running zero-position") + on_error(e, traceback.format_exc(), None) + # Umap may fail due to input sparseness or memory pressure. + # For now, in these cases, we'll just return a layout with all nodes at (0, 0) + result = [] + for i in range(len(nodes)): + cluster = node_clusters[i] if len(node_clusters) > 0 else 1 + result.append( + NodePosition(x=0, y=0, label=nodes[i], size=0, cluster=str(cluster)) + ) + return result + + +def get_zero_positions( + node_labels: list[str], + node_categories: list[int] | None = None, + node_sizes: list[int] | None = None, + three_d: bool | None = False, +) -> list[NodePosition]: + """Project embedding vectors down to 2D/3D using UMAP.""" + embedding_position_data: list[NodePosition] = [] + for index, node_name in enumerate(node_labels): + node_category = 1 if node_categories is None else node_categories[index] + node_size = 1 if node_sizes is None else node_sizes[index] + + if not three_d: + embedding_position_data.append( + NodePosition( + label=str(node_name), + x=0, + y=0, + cluster=str(int(node_category)), + size=int(node_size), + ) + ) + else: + embedding_position_data.append( + NodePosition( + label=str(node_name), + x=0, + y=0, + z=0, + cluster=str(int(node_category)), + size=int(node_size), + ) + ) + return embedding_position_data diff --git a/graphrag/index/operations/summarize_communities/community_reports_extractor/__init__.py b/graphrag/index/operations/summarize_communities/community_reports_extractor/__init__.py new file mode 100644 index 00000000..5daf0df3 --- /dev/null +++ b/graphrag/index/operations/summarize_communities/community_reports_extractor/__init__.py @@ -0,0 +1,26 @@ +# Copyright (c) 2024 Microsoft Corporation. 
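get_zero_positions is the fallback layout: every node lands at the origin and only cluster and size metadata are preserved. A quick sketch:

```python
from graphrag.index.operations.layout_graph.zero import get_zero_positions

positions = get_zero_positions(
    node_labels=["A", "B"],
    node_categories=[1, 2],
    node_sizes=[3, 1],
)
# [NodePosition(label='A', cluster='1', size=3, x=0, y=0, z=None),
#  NodePosition(label='B', cluster='2', size=1, x=0, y=0, z=None)]
```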
+# Licensed under the MIT License + +"""The Indexing Engine community reports package root.""" + +import graphrag.index.operations.summarize_communities.community_reports_extractor.schemas as schemas +from graphrag.index.operations.summarize_communities.community_reports_extractor.build_mixed_context import ( + build_mixed_context, +) +from graphrag.index.operations.summarize_communities.community_reports_extractor.community_reports_extractor import ( + CommunityReportsExtractor, +) +from graphrag.index.operations.summarize_communities.community_reports_extractor.prep_community_report_context import ( + prep_community_report_context, +) +from graphrag.index.operations.summarize_communities.community_reports_extractor.sort_context import ( + sort_context, +) + +__all__ = [ + "CommunityReportsExtractor", + "build_mixed_context", + "prep_community_report_context", + "schemas", + "sort_context", +] diff --git a/graphrag/index/operations/summarize_communities/community_reports_extractor/build_mixed_context.py b/graphrag/index/operations/summarize_communities/community_reports_extractor/build_mixed_context.py new file mode 100644 index 00000000..32f19f77 --- /dev/null +++ b/graphrag/index/operations/summarize_communities/community_reports_extractor/build_mixed_context.py @@ -0,0 +1,70 @@ +# Copyright (c) 2024 Microsoft Corporation. +# Licensed under the MIT License +"""A module containing the build_mixed_context method definition.""" + +import pandas as pd + +import graphrag.index.operations.summarize_communities.community_reports_extractor.schemas as schemas +from graphrag.index.operations.summarize_communities.community_reports_extractor.sort_context import ( + sort_context, +) +from graphrag.query.llm.text_utils import num_tokens + + +def build_mixed_context(context: list[dict], max_tokens: int) -> str: + """ + Build parent context by concatenating all sub-communities' contexts. + + If the context exceeds the limit, we use sub-community reports instead. 
+ """ + sorted_context = sorted( + context, key=lambda x: x[schemas.CONTEXT_SIZE], reverse=True + ) + + # replace local context with sub-community reports, starting from the biggest sub-community + substitute_reports = [] + final_local_contexts = [] + exceeded_limit = True + context_string = "" + + for idx, sub_community_context in enumerate(sorted_context): + if exceeded_limit: + if sub_community_context[schemas.FULL_CONTENT]: + substitute_reports.append({ + schemas.COMMUNITY_ID: sub_community_context[schemas.SUB_COMMUNITY], + schemas.FULL_CONTENT: sub_community_context[schemas.FULL_CONTENT], + }) + else: + # this sub-community has no report, so we will use its local context + final_local_contexts.extend(sub_community_context[schemas.ALL_CONTEXT]) + continue + + # add local context for the remaining sub-communities + remaining_local_context = [] + for rid in range(idx + 1, len(sorted_context)): + remaining_local_context.extend(sorted_context[rid][schemas.ALL_CONTEXT]) + new_context_string = sort_context( + local_context=remaining_local_context + final_local_contexts, + sub_community_reports=substitute_reports, + ) + if num_tokens(new_context_string) <= max_tokens: + exceeded_limit = False + context_string = new_context_string + break + + if exceeded_limit: + # if all sub-community reports exceed the limit, we add reports until context is full + substitute_reports = [] + for sub_community_context in sorted_context: + substitute_reports.append({ + schemas.COMMUNITY_ID: sub_community_context[schemas.SUB_COMMUNITY], + schemas.FULL_CONTENT: sub_community_context[schemas.FULL_CONTENT], + }) + new_context_string = pd.DataFrame(substitute_reports).to_csv( + index=False, sep="," + ) + if num_tokens(new_context_string) > max_tokens: + break + + context_string = new_context_string + return context_string diff --git a/graphrag/index/operations/summarize_communities/community_reports_extractor/community_reports_extractor.py b/graphrag/index/operations/summarize_communities/community_reports_extractor/community_reports_extractor.py new file mode 100644 index 00000000..7fa0b684 --- /dev/null +++ b/graphrag/index/operations/summarize_communities/community_reports_extractor/community_reports_extractor.py @@ -0,0 +1,106 @@ +# Copyright (c) 2024 Microsoft Corporation. +# Licensed under the MIT License + +"""A module containing 'CommunityReportsResult' and 'CommunityReportsExtractor' models.""" + +import logging +import traceback +from dataclasses import dataclass +from typing import Any + +from fnllm import ChatLLM +from pydantic import BaseModel, Field + +from graphrag.index.typing import ErrorHandlerFn +from graphrag.prompts.index.community_report import COMMUNITY_REPORT_PROMPT + +log = logging.getLogger(__name__) + + +class FindingModel(BaseModel): + """A model for the expected LLM response shape.""" + + summary: str = Field(description="The summary of the finding.") + explanation: str = Field(description="An explanation of the finding.") + + +class CommunityReportResponse(BaseModel): + """A model for the expected LLM response shape.""" + + title: str = Field(description="The title of the report.") + summary: str = Field(description="A summary of the report.") + findings: list[FindingModel] = Field( + description="A list of findings in the report." + ) + rating: float = Field(description="The rating of the report.") + rating_explanation: str = Field(description="An explanation of the rating.") + + extra_attributes: dict[str, Any] = Field( + default_factory=dict, description="Extra attributes." 
+ ) + + +@dataclass +class CommunityReportsResult: + """Community reports result class definition.""" + + output: str + structured_output: CommunityReportResponse | None + + +class CommunityReportsExtractor: + """Community reports extractor class definition.""" + + _llm: ChatLLM + _input_text_key: str + _extraction_prompt: str + _output_formatter_prompt: str + _on_error: ErrorHandlerFn + _max_report_length: int + + def __init__( + self, + llm_invoker: ChatLLM, + input_text_key: str | None = None, + extraction_prompt: str | None = None, + on_error: ErrorHandlerFn | None = None, + max_report_length: int | None = None, + ): + """Init method definition.""" + self._llm = llm_invoker + self._input_text_key = input_text_key or "input_text" + self._extraction_prompt = extraction_prompt or COMMUNITY_REPORT_PROMPT + self._on_error = on_error or (lambda _e, _s, _d: None) + self._max_report_length = max_report_length or 1500 + + async def __call__(self, inputs: dict[str, Any]): + """Call method definition.""" + output = None + try: + input_text = inputs[self._input_text_key] + prompt = self._extraction_prompt.replace( + "{" + self._input_text_key + "}", input_text + ) + response = await self._llm( + prompt, + json=True, + name="create_community_report", + json_model=CommunityReportResponse, + model_parameters={"max_tokens": self._max_report_length}, + ) + output = response.parsed_json + except Exception as e: + log.exception("error generating community report") + self._on_error(e, traceback.format_exc(), None) + + text_output = self._get_text_output(output) if output else "" + return CommunityReportsResult( + structured_output=output, + output=text_output, + ) + + def _get_text_output(self, report: CommunityReportResponse) -> str: + report_sections = "\n\n".join( + f"## {f.summary}\n\n{f.explanation}" for f in report.findings + ) + return f"# {report.title}\n\n{report.summary}\n\n{report_sections}" diff --git a/graphrag/index/operations/summarize_communities/community_reports_extractor/prep_community_report_context.py b/graphrag/index/operations/summarize_communities/community_reports_extractor/prep_community_report_context.py new file mode 100644 index 00000000..bb5125f1 --- /dev/null +++ b/graphrag/index/operations/summarize_communities/community_reports_extractor/prep_community_report_context.py @@ -0,0 +1,185 @@ +# Copyright (c) 2024 Microsoft Corporation. +# Licensed under the MIT License + +"""A module containing create_community_reports and load_strategy methods definition.""" + +import logging +from typing import cast + +import pandas as pd + +import graphrag.index.operations.summarize_communities.community_reports_extractor.schemas as schemas +from graphrag.index.operations.summarize_communities.community_reports_extractor.build_mixed_context import ( + build_mixed_context, +) +from graphrag.index.operations.summarize_communities.community_reports_extractor.sort_context import ( + sort_context, +) +from graphrag.index.utils.dataframes import ( + antijoin, + drop_columns, + join, + select, + transform_series, + union, + where_column_equals, +) +from graphrag.query.llm.text_utils import num_tokens + +log = logging.getLogger(__name__) + + +def prep_community_report_context( + report_df: pd.DataFrame | None, + community_hierarchy_df: pd.DataFrame, + local_context_df: pd.DataFrame, + level: int, + max_tokens: int, +) -> pd.DataFrame: + """ + Prep context for each community in a given level. 
+ + For each community: + - Check if local context fits within the limit, if yes use local context + - If local context exceeds the limit, iteratively replace local context with sub-community reports, starting from the biggest sub-community + """ + if report_df is None: + report_df = pd.DataFrame() + + # Filter by community level + level_context_df = local_context_df.loc[ + local_context_df.loc[:, schemas.COMMUNITY_LEVEL] == level + ] + + # Filter valid and invalid contexts using boolean logic + valid_context_df = level_context_df.loc[ + ~level_context_df.loc[:, schemas.CONTEXT_EXCEED_FLAG] + ] + invalid_context_df = level_context_df.loc[ + level_context_df.loc[:, schemas.CONTEXT_EXCEED_FLAG] + ] + + # there is no report to substitute with, so we just trim the local context of the invalid context records + # this case should only happen at the bottom level of the community hierarchy where there are no sub-communities + if invalid_context_df.empty: + return valid_context_df + + if report_df.empty: + invalid_context_df.loc[:, schemas.CONTEXT_STRING] = _sort_and_trim_context( + invalid_context_df, max_tokens + ) + invalid_context_df[schemas.CONTEXT_SIZE] = invalid_context_df.loc[ + :, schemas.CONTEXT_STRING + ].map(num_tokens) + invalid_context_df[schemas.CONTEXT_EXCEED_FLAG] = 0 + return union(valid_context_df, invalid_context_df) + + level_context_df = _antijoin_reports(level_context_df, report_df) + + # for each invalid context, we will try to substitute with sub-community reports + # first get local context and report (if available) for each sub-community + sub_context_df = _get_subcontext_df(level + 1, report_df, local_context_df) + community_df = _get_community_df( + level, invalid_context_df, sub_context_df, community_hierarchy_df, max_tokens + ) + + # handle any remaining invalid records that can't be subsituted with sub-community reports + # this should be rare, but if it happens, we will just trim the local context to fit the limit + remaining_df = _antijoin_reports(invalid_context_df, community_df) + remaining_df.loc[:, schemas.CONTEXT_STRING] = _sort_and_trim_context( + remaining_df, max_tokens + ) + + result = union(valid_context_df, community_df, remaining_df) + result[schemas.CONTEXT_SIZE] = result.loc[:, schemas.CONTEXT_STRING].map(num_tokens) + + result[schemas.CONTEXT_EXCEED_FLAG] = 0 + return result + + +def _drop_community_level(df: pd.DataFrame) -> pd.DataFrame: + """Drop the community level column from the dataframe.""" + return drop_columns(df, schemas.COMMUNITY_LEVEL) + + +def _at_level(level: int, df: pd.DataFrame) -> pd.DataFrame: + """Return records at the given level.""" + return where_column_equals(df, schemas.COMMUNITY_LEVEL, level) + + +def _antijoin_reports(df: pd.DataFrame, reports: pd.DataFrame) -> pd.DataFrame: + """Return records in df that are not in reports.""" + return antijoin(df, reports, schemas.NODE_COMMUNITY) + + +def _sort_and_trim_context(df: pd.DataFrame, max_tokens: int) -> pd.Series: + """Sort and trim context to fit the limit.""" + series = cast("pd.Series", df[schemas.ALL_CONTEXT]) + return transform_series(series, lambda x: sort_context(x, max_tokens=max_tokens)) + + +def _build_mixed_context(df: pd.DataFrame, max_tokens: int) -> pd.Series: + """Sort and trim context to fit the limit.""" + series = cast("pd.Series", df[schemas.ALL_CONTEXT]) + return transform_series( + series, lambda x: build_mixed_context(x, max_tokens=max_tokens) + ) + + +def _get_subcontext_df( + level: int, report_df: pd.DataFrame, local_context_df: pd.DataFrame +) -> 
pd.DataFrame: + """Get sub-community context for each community.""" + sub_report_df = _drop_community_level(_at_level(level, report_df)) + sub_context_df = _at_level(level, local_context_df) + sub_context_df = join(sub_context_df, sub_report_df, schemas.NODE_COMMUNITY) + sub_context_df.rename( + columns={schemas.NODE_COMMUNITY: schemas.SUB_COMMUNITY}, inplace=True + ) + return sub_context_df + + +def _get_community_df( + level: int, + invalid_context_df: pd.DataFrame, + sub_context_df: pd.DataFrame, + community_hierarchy_df: pd.DataFrame, + max_tokens: int, +) -> pd.DataFrame: + """Get community context for each community.""" + # collect all sub communities' contexts for each community + community_df = _drop_community_level(_at_level(level, community_hierarchy_df)) + invalid_community_ids = select(invalid_context_df, schemas.NODE_COMMUNITY) + subcontext_selection = select( + sub_context_df, + schemas.SUB_COMMUNITY, + schemas.FULL_CONTENT, + schemas.ALL_CONTEXT, + schemas.CONTEXT_SIZE, + ) + + invalid_communities = join( + community_df, invalid_community_ids, schemas.NODE_COMMUNITY, "inner" + ) + community_df = join( + invalid_communities, subcontext_selection, schemas.SUB_COMMUNITY + ) + community_df[schemas.ALL_CONTEXT] = community_df.apply( + lambda x: { + schemas.SUB_COMMUNITY: x[schemas.SUB_COMMUNITY], + schemas.ALL_CONTEXT: x[schemas.ALL_CONTEXT], + schemas.FULL_CONTENT: x[schemas.FULL_CONTENT], + schemas.CONTEXT_SIZE: x[schemas.CONTEXT_SIZE], + }, + axis=1, + ) + community_df = ( + community_df.groupby(schemas.NODE_COMMUNITY) + .agg({schemas.ALL_CONTEXT: list}) + .reset_index() + ) + community_df[schemas.CONTEXT_STRING] = _build_mixed_context( + community_df, max_tokens + ) + community_df[schemas.COMMUNITY_LEVEL] = level + return community_df diff --git a/graphrag/index/operations/summarize_communities/community_reports_extractor/schemas.py b/graphrag/index/operations/summarize_communities/community_reports_extractor/schemas.py new file mode 100644 index 00000000..9c8e29b6 --- /dev/null +++ b/graphrag/index/operations/summarize_communities/community_reports_extractor/schemas.py @@ -0,0 +1,52 @@ +# Copyright (c) 2024 Microsoft Corporation. 
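build_mixed_context consumes one record per sub-community, keyed by the schema constants defined below. A sketch of the expected shape with invented content (the report text and token budget are placeholders):

```python
from graphrag.index.operations.summarize_communities.community_reports_extractor import (
    build_mixed_context,
    schemas,
)

context = [
    {
        schemas.SUB_COMMUNITY: 12,
        schemas.FULL_CONTENT: "# Sub-community 12 report ...",
        schemas.ALL_CONTEXT: [],  # local context records; unused when a report exists
        schemas.CONTEXT_SIZE: 1800,
    },
    {
        schemas.SUB_COMMUNITY: 13,
        schemas.FULL_CONTENT: "# Sub-community 13 report ...",
        schemas.ALL_CONTEXT: [],
        schemas.CONTEXT_SIZE: 900,
    },
]
context_string = build_mixed_context(context, max_tokens=16_000)
# the largest sub-communities are substituted with their full reports first
```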
+# Licensed under the MIT License +"""Common field name definitions for community reports.""" + +# POST-PREP NODE TABLE SCHEMA +NODE_ID = "human_readable_id" +NODE_NAME = "title" +NODE_DESCRIPTION = "description" +NODE_DEGREE = "degree" +NODE_DETAILS = "node_details" +NODE_COMMUNITY = "community" +NODE_LEVEL = "level" + +# POST-PREP EDGE TABLE SCHEMA +EDGE_ID = "human_readable_id" +EDGE_SOURCE = "source" +EDGE_TARGET = "target" +EDGE_DESCRIPTION = "description" +EDGE_DEGREE = "combined_degree" +EDGE_DETAILS = "edge_details" +EDGE_WEIGHT = "weight" + +# POST-PREP CLAIM TABLE SCHEMA +CLAIM_ID = "human_readable_id" +CLAIM_SUBJECT = "subject_id" +CLAIM_TYPE = "type" +CLAIM_STATUS = "status" +CLAIM_DESCRIPTION = "description" +CLAIM_DETAILS = "claim_details" + +# COMMUNITY HIERARCHY TABLE SCHEMA +SUB_COMMUNITY = "sub_community" +SUB_COMMUNITY_SIZE = "sub_community_size" +COMMUNITY_LEVEL = "level" + +# COMMUNITY CONTEXT TABLE SCHEMA +ALL_CONTEXT = "all_context" +CONTEXT_STRING = "context_string" +CONTEXT_SIZE = "context_size" +CONTEXT_EXCEED_FLAG = "context_exceed_limit" + +# COMMUNITY REPORT TABLE SCHEMA +REPORT_ID = "id" +COMMUNITY_ID = "community" +COMMUNITY_LEVEL = "level" +TITLE = "title" +SUMMARY = "summary" +FINDINGS = "findings" +RATING = "rank" +EXPLANATION = "rating_explanation" +FULL_CONTENT = "full_content" +FULL_CONTENT_JSON = "full_content_json" diff --git a/graphrag/index/operations/summarize_communities/community_reports_extractor/sort_context.py b/graphrag/index/operations/summarize_communities/community_reports_extractor/sort_context.py new file mode 100644 index 00000000..cd17578b --- /dev/null +++ b/graphrag/index/operations/summarize_communities/community_reports_extractor/sort_context.py @@ -0,0 +1,154 @@ +# Copyright (c) 2024 Microsoft Corporation. 
+# Licensed under the MIT License +"""Sort context by degree in descending order.""" + +import pandas as pd + +import graphrag.index.operations.summarize_communities.community_reports_extractor.schemas as schemas +from graphrag.query.llm.text_utils import num_tokens + + +def sort_context( + local_context: list[dict], + sub_community_reports: list[dict] | None = None, + max_tokens: int | None = None, + node_name_column: str = schemas.NODE_NAME, + node_details_column: str = schemas.NODE_DETAILS, + edge_id_column: str = schemas.EDGE_ID, + edge_details_column: str = schemas.EDGE_DETAILS, + edge_degree_column: str = schemas.EDGE_DEGREE, + edge_source_column: str = schemas.EDGE_SOURCE, + edge_target_column: str = schemas.EDGE_TARGET, + claim_details_column: str = schemas.CLAIM_DETAILS, +) -> str: + """Sort context by degree in descending order, optimizing for performance.""" + + def _get_context_string( + entities: list[dict], + edges: list[dict], + claims: list[dict], + sub_community_reports: list[dict] | None = None, + ) -> str: + """Concatenate structured data into a context string.""" + contexts = [] + if sub_community_reports: + report_df = pd.DataFrame(sub_community_reports) + if not report_df.empty: + contexts.append( + f"----Reports-----\n{report_df.to_csv(index=False, sep=',')}" + ) + + for label, data in [ + ("Entities", entities), + ("Claims", claims), + ("Relationships", edges), + ]: + if data: + data_df = pd.DataFrame(data) + if not data_df.empty: + contexts.append( + f"-----{label}-----\n{data_df.to_csv(index=False, sep=',')}" + ) + + return "\n\n".join(contexts) + + # Preprocess local context + edges = [ + {**e, schemas.EDGE_ID: int(e[schemas.EDGE_ID])} + for record in local_context + for e in record.get(edge_details_column, []) + if isinstance(e, dict) + ] + + node_details = { + record[node_name_column]: { + **record[node_details_column], + schemas.NODE_ID: int(record[node_details_column][schemas.NODE_ID]), + } + for record in local_context + } + + claim_details = { + record[node_name_column]: [ + {**c, schemas.CLAIM_ID: int(c[schemas.CLAIM_ID])} + for c in record.get(claim_details_column, []) + if isinstance(c, dict) and c.get(schemas.CLAIM_ID) is not None + ] + for record in local_context + if isinstance(record.get(claim_details_column), list) + } + + # Sort edges by degree (desc) and ID (asc) + edges.sort(key=lambda x: (-x.get(edge_degree_column, 0), x.get(edge_id_column, ""))) + + # Deduplicate and build context incrementally + edge_ids, nodes_ids, claims_ids = set(), set(), set() + sorted_edges, sorted_nodes, sorted_claims = [], [], [] + context_string = "" + + for edge in edges: + source, target = edge[edge_source_column], edge[edge_target_column] + + # Add source and target node details + for node in [node_details.get(source), node_details.get(target)]: + if node and node[schemas.NODE_ID] not in nodes_ids: + nodes_ids.add(node[schemas.NODE_ID]) + sorted_nodes.append(node) + + # Add claims related to source and target + for claims in [claim_details.get(source), claim_details.get(target)]: + if claims: + for claim in claims: + if claim[schemas.CLAIM_ID] not in claims_ids: + claims_ids.add(claim[schemas.CLAIM_ID]) + sorted_claims.append(claim) + + # Add the edge + if edge[schemas.EDGE_ID] not in edge_ids: + edge_ids.add(edge[schemas.EDGE_ID]) + sorted_edges.append(edge) + + # Generate new context string + new_context_string = _get_context_string( + sorted_nodes, sorted_edges, sorted_claims, sub_community_reports + ) + if max_tokens and num_tokens(new_context_string) > 
max_tokens: + break + context_string = new_context_string + + # Return the final context string + return context_string or _get_context_string( + sorted_nodes, sorted_edges, sorted_claims, sub_community_reports + ) + + +def parallel_sort_context_batch(community_df, max_tokens, parallel=False): + """Calculate context using parallelization if enabled.""" + if parallel: + # Use ThreadPoolExecutor for parallel execution + from concurrent.futures import ThreadPoolExecutor + + with ThreadPoolExecutor(max_workers=None) as executor: + context_strings = list( + executor.map( + lambda x: sort_context(x, max_tokens=max_tokens), + community_df[schemas.ALL_CONTEXT], + ) + ) + community_df[schemas.CONTEXT_STRING] = context_strings + + else: + # Assign context strings directly to the DataFrame + community_df[schemas.CONTEXT_STRING] = community_df[schemas.ALL_CONTEXT].apply( + lambda context_list: sort_context(context_list, max_tokens=max_tokens) + ) + + # Calculate other columns + community_df[schemas.CONTEXT_SIZE] = community_df[schemas.CONTEXT_STRING].apply( + num_tokens + ) + community_df[schemas.CONTEXT_EXCEED_FLAG] = ( + community_df[schemas.CONTEXT_SIZE] > max_tokens + ) + + return community_df diff --git a/graphrag/index/operations/summarize_communities/community_reports_extractor/utils.py b/graphrag/index/operations/summarize_communities/community_reports_extractor/utils.py new file mode 100644 index 00000000..c847451b --- /dev/null +++ b/graphrag/index/operations/summarize_communities/community_reports_extractor/utils.py @@ -0,0 +1,15 @@ +# Copyright (c) 2024 Microsoft Corporation. +# Licensed under the MIT License + +"""A module containing community report generation utilities.""" + +import pandas as pd + +import graphrag.index.operations.summarize_communities.community_reports_extractor.schemas as schemas + + +def get_levels(df: pd.DataFrame, level_column: str = schemas.NODE_LEVEL) -> list[int]: + """Get the levels of the communities.""" + levels = df[level_column].dropna().unique() + levels = [int(lvl) for lvl in levels if lvl != -1] + return sorted(levels, reverse=True) diff --git a/graphrag/index/operations/summarize_descriptions/description_summary_extractor.py b/graphrag/index/operations/summarize_descriptions/description_summary_extractor.py new file mode 100644 index 00000000..ee59c649 --- /dev/null +++ b/graphrag/index/operations/summarize_descriptions/description_summary_extractor.py @@ -0,0 +1,140 @@ +# Copyright (c) 2024 Microsoft Corporation. 
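A minimal sketch of calling the `sort_context` helper above directly with one hand-built local-context record (the node, edge, and description values are invented for illustration); the returned string contains the CSV-formatted Entities/Relationships sections that feed community report generation:

```python
import graphrag.index.operations.summarize_communities.community_reports_extractor.schemas as schemas
from graphrag.index.operations.summarize_communities.community_reports_extractor.sort_context import (
    sort_context,
)

# One hand-built local-context record: a node, its details, one edge, no claims.
record = {
    schemas.NODE_NAME: "ALICE",
    schemas.NODE_DETAILS: {
        schemas.NODE_ID: 0,
        schemas.NODE_NAME: "ALICE",
        schemas.NODE_DESCRIPTION: "A researcher",
        schemas.NODE_DEGREE: 2,
    },
    schemas.EDGE_DETAILS: [
        {
            schemas.EDGE_ID: 0,
            schemas.EDGE_SOURCE: "ALICE",
            schemas.EDGE_TARGET: "BOB",
            schemas.EDGE_DESCRIPTION: "Alice works with Bob",
            schemas.EDGE_DEGREE: 4,
        }
    ],
    schemas.CLAIM_DETAILS: [],
}

# Degree-sorted, token-limited context string with Entities/Relationships sections.
context = sort_context([record], max_tokens=16_000)
print(context)
```

`parallel_sort_context_batch` applies the same function to each community row and records the token count and the overflow flag alongside the generated context string.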
+# Licensed under the MIT License
+
+"""A module containing 'SummarizationResult' and 'SummarizeExtractor' models."""
+
+import json
+from dataclasses import dataclass
+
+from fnllm import ChatLLM
+
+from graphrag.index.typing import ErrorHandlerFn
+from graphrag.index.utils.tokens import num_tokens_from_string
+from graphrag.prompts.index.summarize_descriptions import SUMMARIZE_PROMPT
+
+# Max token size for input prompts
+DEFAULT_MAX_INPUT_TOKENS = 4_000
+# Max token count for LLM answers
+DEFAULT_MAX_SUMMARY_LENGTH = 500
+
+
+@dataclass
+class SummarizationResult:
+    """Description summarization result class definition."""
+
+    id: str | tuple[str, str]
+    description: str
+
+
+class SummarizeExtractor:
+    """Description summarization extractor class definition."""
+
+    _llm: ChatLLM
+    _entity_name_key: str
+    _input_descriptions_key: str
+    _summarization_prompt: str
+    _on_error: ErrorHandlerFn
+    _max_summary_length: int
+    _max_input_tokens: int
+
+    def __init__(
+        self,
+        llm_invoker: ChatLLM,
+        entity_name_key: str | None = None,
+        input_descriptions_key: str | None = None,
+        summarization_prompt: str | None = None,
+        on_error: ErrorHandlerFn | None = None,
+        max_summary_length: int | None = None,
+        max_input_tokens: int | None = None,
+    ):
+        """Init method definition."""
+        # TODO: streamline construction
+        self._llm = llm_invoker
+        self._entity_name_key = entity_name_key or "entity_name"
+        self._input_descriptions_key = input_descriptions_key or "description_list"
+
+        self._summarization_prompt = summarization_prompt or SUMMARIZE_PROMPT
+        self._on_error = on_error or (lambda _e, _s, _d: None)
+        self._max_summary_length = max_summary_length or DEFAULT_MAX_SUMMARY_LENGTH
+        self._max_input_tokens = max_input_tokens or DEFAULT_MAX_INPUT_TOKENS
+
+    async def __call__(
+        self,
+        id: str | tuple[str, str],
+        descriptions: list[str],
+    ) -> SummarizationResult:
+        """Call method definition."""
+        result = ""
+        if len(descriptions) == 0:
+            result = ""
+        elif len(descriptions) == 1:
+            result = descriptions[0]
+        else:
+            result = await self._summarize_descriptions(id, descriptions)
+
+        return SummarizationResult(
+            id=id,
+            description=result or "",
+        )
+
+    async def _summarize_descriptions(
+        self, id: str | tuple[str, str], descriptions: list[str]
+    ) -> str:
+        """Summarize descriptions into a single description."""
+        sorted_id = sorted(id) if isinstance(id, list) else id
+
+        # Safety check, should always be a list
+        if not isinstance(descriptions, list):
+            descriptions = [descriptions]
+
+        # Sort description lists
+        if len(descriptions) > 1:
+            descriptions = sorted(descriptions)
+
+        # Iterate over descriptions, adding all until the max input tokens is reached
+        usable_tokens = self._max_input_tokens - num_tokens_from_string(
+            self._summarization_prompt
+        )
+        descriptions_collected = []
+        result = ""
+
+        for i, description in enumerate(descriptions):
+            usable_tokens -= num_tokens_from_string(description)
+            descriptions_collected.append(description)
+
+            # If buffer is full, or all descriptions have been added, summarize
+            if (usable_tokens < 0 and len(descriptions_collected) > 1) or (
+                i == len(descriptions) - 1
+            ):
+                # Calculate result (final or partial)
+                result = await self._summarize_descriptions_with_llm(
+                    sorted_id, descriptions_collected
+                )
+
+                # If we go for another loop, reset values to new
+                if i != len(descriptions) - 1:
+                    descriptions_collected = [result]
+                    usable_tokens = (
+                        self._max_input_tokens
+                        - num_tokens_from_string(self._summarization_prompt)
+                        - num_tokens_from_string(result)
+                    )
+
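+        # At this point `result` holds the output of the most recent LLM call:
+        # either the single summary of all descriptions, or the final pass of the
+        # folding loop above when the description list exceeded the token budget.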
+        return result
+
+    async def _summarize_descriptions_with_llm(
+        self, id: str | tuple[str, str] | list[str], descriptions: list[str]
+    ):
+        """Summarize descriptions using the LLM."""
+        response = await self._llm(
+            self._summarization_prompt.format(**{
+                self._entity_name_key: json.dumps(id, ensure_ascii=False),
+                self._input_descriptions_key: json.dumps(
+                    sorted(descriptions), ensure_ascii=False
+                ),
+            }),
+            name="summarize",
+            model_parameters={"max_tokens": self._max_summary_length},
+        )
+        # Calculate result
+        return str(response.output.content)
diff --git a/graphrag/index/operations/summarize_descriptions/graph_intelligence_strategy.py b/graphrag/index/operations/summarize_descriptions/graph_intelligence_strategy.py
new file mode 100644
index 00000000..4a22b9b5
--- /dev/null
+++ b/graphrag/index/operations/summarize_descriptions/graph_intelligence_strategy.py
@@ -0,0 +1,64 @@
+# Copyright (c) 2024 Microsoft Corporation.
+# Licensed under the MIT License
+
+"""A module containing run_graph_intelligence and run_summarize_descriptions methods to run description summarization."""
+
+from datashaper import VerbCallbacks
+from fnllm import ChatLLM
+
+from graphrag.cache.pipeline_cache import PipelineCache
+from graphrag.index.llm.load_llm import load_llm, read_llm_params
+from graphrag.index.operations.summarize_descriptions.description_summary_extractor import (
+    SummarizeExtractor,
+)
+from graphrag.index.operations.summarize_descriptions.typing import (
+    StrategyConfig,
+    SummarizedDescriptionResult,
+)
+
+
+async def run_graph_intelligence(
+    id: str | tuple[str, str],
+    descriptions: list[str],
+    callbacks: VerbCallbacks,
+    cache: PipelineCache,
+    args: StrategyConfig,
+) -> SummarizedDescriptionResult:
+    """Run the graph intelligence description summarization strategy."""
+    llm_config = read_llm_params(args.get("llm", {}))
+    llm = load_llm(
+        "summarize_descriptions", llm_config, callbacks=callbacks, cache=cache
+    )
+    return await run_summarize_descriptions(llm, id, descriptions, callbacks, args)
+
+
+async def run_summarize_descriptions(
+    llm: ChatLLM,
+    id: str | tuple[str, str],
+    descriptions: list[str],
+    callbacks: VerbCallbacks,
+    args: StrategyConfig,
+) -> SummarizedDescriptionResult:
+    """Run the description summarization chain."""
+    # Summarization arguments
+    summarize_prompt = args.get("summarize_prompt", None)
+    entity_name_key = args.get("entity_name_key", "entity_name")
+    input_descriptions_key = args.get("input_descriptions_key", "description_list")
+    max_tokens = args.get("max_tokens", None)
+
+    extractor = SummarizeExtractor(
+        llm_invoker=llm,
+        summarization_prompt=summarize_prompt,
+        entity_name_key=entity_name_key,
+        input_descriptions_key=input_descriptions_key,
+        on_error=lambda e, stack, details: (
+            callbacks.error("Entity Extraction Error", e, stack, details)
+            if callbacks
+            else None
+        ),
+        max_summary_length=args.get("max_summary_length", None),
+        max_input_tokens=max_tokens,
+    )
+
+    result = await extractor(id=id, descriptions=descriptions)
+    return SummarizedDescriptionResult(id=result.id, description=result.description)
diff --git a/graphrag/index/utils/stable_lcc.py b/graphrag/index/utils/stable_lcc.py
new file mode 100644
index 00000000..07031133
--- /dev/null
+++ b/graphrag/index/utils/stable_lcc.py
@@ -0,0 +1,67 @@
+# Copyright (c) 2024 Microsoft Corporation.
+# Licensed under the MIT License
+
+"""A module for producing a stable largest connected component, i.e.
same input graph == same output lcc.""" + +import html +from typing import Any, cast + +import networkx as nx + + +def stable_largest_connected_component(graph: nx.Graph) -> nx.Graph: + """Return the largest connected component of the graph, with nodes and edges sorted in a stable way.""" + # NOTE: The import is done here to reduce the initial import time of the module + from graspologic.utils import largest_connected_component + + graph = graph.copy() + graph = cast("nx.Graph", largest_connected_component(graph)) + graph = normalize_node_names(graph) + return _stabilize_graph(graph) + + +def _stabilize_graph(graph: nx.Graph) -> nx.Graph: + """Ensure an undirected graph with the same relationships will always be read the same way.""" + fixed_graph = nx.DiGraph() if graph.is_directed() else nx.Graph() + + sorted_nodes = graph.nodes(data=True) + sorted_nodes = sorted(sorted_nodes, key=lambda x: x[0]) + + fixed_graph.add_nodes_from(sorted_nodes) + edges = list(graph.edges(data=True)) + + # If the graph is undirected, we create the edges in a stable way, so we get the same results + # for example: + # A -> B + # in graph theory is the same as + # B -> A + # in an undirected graph + # however, this can lead to downstream issues because sometimes + # consumers read graph.nodes() which ends up being [A, B] and sometimes it's [B, A] + # but they base some of their logic on the order of the nodes, so the order ends up being important + # so we sort the nodes in the edge in a stable way, so that we always get the same order + if not graph.is_directed(): + + def _sort_source_target(edge): + source, target, edge_data = edge + if source > target: + temp = source + source = target + target = temp + return source, target, edge_data + + edges = [_sort_source_target(edge) for edge in edges] + + def _get_edge_key(source: Any, target: Any) -> str: + return f"{source} -> {target}" + + edges = sorted(edges, key=lambda x: _get_edge_key(x[0], x[1])) + + fixed_graph.add_edges_from(edges) + return fixed_graph + + +def normalize_node_names(graph: nx.Graph | nx.DiGraph) -> nx.Graph | nx.DiGraph: + """Normalize node names.""" + node_mapping = {node: html.unescape(node.upper().strip()) for node in graph.nodes()} # type: ignore + return nx.relabel_nodes(graph, node_mapping) diff --git a/graphrag/py.typed b/graphrag/py.typed new file mode 100644 index 00000000..f4bd2989 --- /dev/null +++ b/graphrag/py.typed @@ -0,0 +1,2 @@ +# This package supports type hinting, +# see https://www.python.org/dev/peps/pep-0561/#packaging-type-information \ No newline at end of file diff --git a/tests/integration/storage/test_cosmosdb_storage.py b/tests/integration/storage/test_cosmosdb_storage.py new file mode 100644 index 00000000..5b82e0d8 --- /dev/null +++ b/tests/integration/storage/test_cosmosdb_storage.py @@ -0,0 +1,109 @@ +# Copyright (c) 2024 Microsoft Corporation. 
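A small sketch of the stable LCC helper above on a toy graph, assuming `graspologic` is installed; node names are normalized (upper-cased, HTML-unescaped) and node/edge ordering is deterministic across runs:

```python
import networkx as nx

from graphrag.index.utils.stable_lcc import stable_largest_connected_component

g = nx.Graph()
g.add_edge("b", "a")    # undirected edge, deliberately added "backwards"
g.add_edge("a", "c")
g.add_node("isolated")  # not part of the largest connected component

lcc = stable_largest_connected_component(g)
print(list(lcc.nodes()))  # ['A', 'B', 'C'] -- normalized and sorted
print(list(lcc.edges()))  # [('A', 'B'), ('A', 'C')] -- stable source/target order
```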
+# Licensed under the MIT License +"""CosmosDB Storage Tests.""" + +import json +import re +import sys + +import pytest + +from graphrag.storage.cosmosdb_pipeline_storage import CosmosDBPipelineStorage + +# cspell:disable-next-line well-known-key +WELL_KNOWN_COSMOS_CONNECTION_STRING = "AccountEndpoint=https://127.0.0.1:8081/;AccountKey=C2y6yDjf5/R+ob0N8A7Cgv30VRDJIWEHLM+4QDU5DE2nQ9nDuVTqobD4b8mGGyPMbIZnqyMsEcaGQy67XIw/Jw==" + +# the cosmosdb emulator is only available on windows runners at this time +if not sys.platform.startswith("win"): + pytest.skip( + "encountered windows-only tests -- will skip for now", allow_module_level=True + ) + + +async def test_find(): + storage = CosmosDBPipelineStorage( + connection_string=WELL_KNOWN_COSMOS_CONNECTION_STRING, + database_name="testfind", + container_name="testfindcontainer", + ) + try: + try: + items = list(storage.find(file_pattern=re.compile(r".*\.json$"))) + items = [item[0] for item in items] + assert items == [] + + json_content = { + "content": "Merry Christmas!", + } + await storage.set( + "christmas.json", json.dumps(json_content), encoding="utf-8" + ) + items = list(storage.find(file_pattern=re.compile(r".*\.json$"))) + items = [item[0] for item in items] + assert items == ["christmas.json"] + + json_content = { + "content": "Hello, World!", + } + await storage.set("test.json", json.dumps(json_content), encoding="utf-8") + items = list(storage.find(file_pattern=re.compile(r".*\.json$"))) + items = [item[0] for item in items] + assert items == ["christmas.json", "test.json"] + + output = await storage.get("test.json") + output_json = json.loads(output) + assert output_json["content"] == "Hello, World!" + + json_exists = await storage.has("christmas.json") + assert json_exists is True + json_exists = await storage.has("easter.json") + assert json_exists is False + finally: + await storage.delete("test.json") + output = await storage.get("test.json") + assert output is None + finally: + await storage.clear() + + +async def test_child(): + storage = CosmosDBPipelineStorage( + connection_string=WELL_KNOWN_COSMOS_CONNECTION_STRING, + database_name="testchild", + container_name="testchildcontainer", + ) + try: + child_storage = storage.child("child") + assert type(child_storage) is CosmosDBPipelineStorage + finally: + await storage.clear() + + +async def test_clear(): + storage = CosmosDBPipelineStorage( + connection_string=WELL_KNOWN_COSMOS_CONNECTION_STRING, + database_name="testclear", + container_name="testclearcontainer", + ) + try: + json_exists = { + "content": "Merry Christmas!", + } + await storage.set("christmas.json", json.dumps(json_exists), encoding="utf-8") + json_exists = { + "content": "Happy Easter!", + } + await storage.set("easter.json", json.dumps(json_exists), encoding="utf-8") + await storage.clear() + + items = list(storage.find(file_pattern=re.compile(r".*\.json$"))) + items = [item[0] for item in items] + assert items == [] + + output = await storage.get("easter.json") + assert output is None + + assert storage._container_client is None # noqa: SLF001 + assert storage._database_client is None # noqa: SLF001 + finally: + await storage.clear()
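For local experimentation, the storage class exercised by these tests can also be driven directly against the Cosmos DB emulator (Windows only, as noted above). A rough sketch using the same well-known emulator connection string; the database and container names here are invented for the example:

```python
import asyncio
import json

from graphrag.storage.cosmosdb_pipeline_storage import CosmosDBPipelineStorage

# Well-known local emulator account key, as used in the tests above.
CONNECTION_STRING = "AccountEndpoint=https://127.0.0.1:8081/;AccountKey=C2y6yDjf5/R+ob0N8A7Cgv30VRDJIWEHLM+4QDU5DE2nQ9nDuVTqobD4b8mGGyPMbIZnqyMsEcaGQy67XIw/Jw=="


async def main() -> None:
    storage = CosmosDBPipelineStorage(
        connection_string=CONNECTION_STRING,
        database_name="scratch",
        container_name="scratchcontainer",
    )
    try:
        # Round-trip a small JSON document, mirroring the test above.
        await storage.set(
            "hello.json", json.dumps({"content": "Hello, World!"}), encoding="utf-8"
        )
        print(await storage.get("hello.json"))  # {"content": "Hello, World!"}
        print(await storage.has("hello.json"))  # True
    finally:
        await storage.clear()


asyncio.run(main())
```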