mirror of
https://github.com/microsoft/graphrag.git
synced 2026-01-14 00:57:23 +08:00
Fix a bunch of module comments and function visibility
This commit is contained in:
parent
3201f28bea
commit
76f9862465
@ -1,7 +1,7 @@
|
||||
# Copyright (c) 2024 Microsoft Corporation.
|
||||
# Licensed under the MIT License
|
||||
|
||||
"""A module containing 'JsonPipelineCache' model."""
|
||||
"""A module containing 'JsonCache' model."""
|
||||
|
||||
import json
|
||||
from typing import Any
|
||||
|
||||
@ -1,7 +1,7 @@
|
||||
# Copyright (c) 2024 Microsoft Corporation.
|
||||
# Licensed under the MIT License
|
||||
|
||||
"""A module containing the WorkflowCallbacks registry."""
|
||||
"""A module containing 'WorkflowCallbacksManager' model."""
|
||||
|
||||
from graphrag.callbacks.workflow_callbacks import WorkflowCallbacks
|
||||
from graphrag.index.typing.pipeline_run_result import PipelineRunResult
|
||||
|
||||
@ -1,7 +1,7 @@
|
||||
# Copyright (c) 2024 Microsoft Corporation.
|
||||
# Licensed under the MIT License
|
||||
|
||||
"""A module containing load method definition."""
|
||||
"""A module containing 'CSVFileReader' model."""
|
||||
|
||||
import logging
|
||||
from io import BytesIO
|
||||
|
||||
@ -1,7 +1,7 @@
|
||||
# Copyright (c) 2024 Microsoft Corporation.
|
||||
# Licensed under the MIT License
|
||||
|
||||
"""A module containing create_input method definition."""
|
||||
"""A module containing 'InputReaderFactory' model."""
|
||||
|
||||
import logging
|
||||
|
||||
|
||||
@ -1,7 +1,7 @@
|
||||
# Copyright (c) 2024 Microsoft Corporation.
|
||||
# Licensed under the MIT License
|
||||
|
||||
"""A module containing 'PipelineCache' model."""
|
||||
"""A module containing 'InputReader' model."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
|
||||
@ -1,7 +1,7 @@
|
||||
# Copyright (c) 2024 Microsoft Corporation.
|
||||
# Licensed under the MIT License
|
||||
|
||||
"""A module containing load method definition."""
|
||||
"""A module containing 'JSONFileReader' model."""
|
||||
|
||||
import json
|
||||
import logging
|
||||
|
||||
@ -1,7 +1,7 @@
|
||||
# Copyright (c) 2024 Microsoft Corporation.
|
||||
# Licensed under the MIT License
|
||||
|
||||
"""A module containing load method definition."""
|
||||
"""A module containing 'TextFileReader' model."""
|
||||
|
||||
import logging
|
||||
from pathlib import Path
|
||||
|
||||
@ -1,7 +1,7 @@
|
||||
# Copyright (c) 2024 Microsoft Corporation.
|
||||
# Licensed under the MIT License
|
||||
|
||||
"""A module containing _get_num_total, chunk, run_strategy and load_strategy methods definitions."""
|
||||
"""A module containing chunk_text method definitions."""
|
||||
|
||||
from typing import Any, cast
|
||||
|
||||
@ -54,7 +54,7 @@ def chunk_text(
|
||||
strategy: sentence
|
||||
```
|
||||
"""
|
||||
strategy_exec = load_strategy(strategy)
|
||||
strategy_exec = _load_strategy(strategy)
|
||||
|
||||
num_total = _get_num_total(input, column)
|
||||
tick = progress_ticker(callbacks.progress, num_total)
|
||||
@ -67,7 +67,7 @@ def chunk_text(
|
||||
input.apply(
|
||||
cast(
|
||||
"Any",
|
||||
lambda x: run_strategy(
|
||||
lambda x: _run_strategy(
|
||||
strategy_exec,
|
||||
x[column],
|
||||
config,
|
||||
@ -79,7 +79,7 @@ def chunk_text(
|
||||
)
|
||||
|
||||
|
||||
def run_strategy(
|
||||
def _run_strategy(
|
||||
strategy_exec: ChunkStrategy,
|
||||
input: ChunkInput,
|
||||
config: ChunkingConfig,
|
||||
@ -111,7 +111,7 @@ def run_strategy(
|
||||
return results
|
||||
|
||||
|
||||
def load_strategy(strategy: ChunkStrategyType) -> ChunkStrategy:
|
||||
def _load_strategy(strategy: ChunkStrategyType) -> ChunkStrategy:
|
||||
"""Load strategy method definition."""
|
||||
match strategy:
|
||||
case ChunkStrategyType.tokens:
|
||||
|
||||
@ -1,7 +1,7 @@
|
||||
# Copyright (c) 2024 Microsoft Corporation.
|
||||
# Licensed under the MIT License
|
||||
|
||||
"""A module containing chunk strategies."""
|
||||
"""A module containing run_tokens and run_sentences methods."""
|
||||
|
||||
from collections.abc import Iterable
|
||||
|
||||
|
||||
@ -1,7 +1,7 @@
|
||||
# Copyright (c) 2024 Microsoft Corporation.
|
||||
# Licensed under the MIT License
|
||||
|
||||
"""A module containing cluster_graph, apply_clustering methods definition."""
|
||||
"""A module containing cluster_graph method definition."""
|
||||
|
||||
import logging
|
||||
|
||||
|
||||
@ -1,7 +1,7 @@
|
||||
# Copyright (c) 2024 Microsoft Corporation.
|
||||
# Licensed under the MIT License
|
||||
|
||||
"""A module containing create_graph definition."""
|
||||
"""A module containing compute_degree method definition."""
|
||||
|
||||
import networkx as nx
|
||||
import pandas as pd
|
||||
|
||||
@ -1,7 +1,7 @@
|
||||
# Copyright (c) 2024 Microsoft Corporation.
|
||||
# Licensed under the MIT License
|
||||
|
||||
"""A module containing compute_edge_combined_degree methods definition."""
|
||||
"""A module containing compute_edge_combined_degree method definition."""
|
||||
|
||||
from typing import cast
|
||||
|
||||
|
||||
@ -1,7 +1,7 @@
|
||||
# Copyright (c) 2024 Microsoft Corporation.
|
||||
# Licensed under the MIT License
|
||||
|
||||
"""A module containing embed_text, load_strategy and create_row_from_embedding_data methods definition."""
|
||||
"""A module containing embed_text method definition."""
|
||||
|
||||
import logging
|
||||
|
||||
|
||||
@ -1,7 +1,7 @@
|
||||
# Copyright (c) 2024 Microsoft Corporation.
|
||||
# Licensed under the MIT License
|
||||
|
||||
"""A module containing run method definition."""
|
||||
"""A module containing 'TextEmbeddingResult' model and run_embed_text method definition."""
|
||||
|
||||
import asyncio
|
||||
import logging
|
||||
|
||||
@ -1,7 +1,7 @@
|
||||
# Copyright (c) 2024 Microsoft Corporation.
|
||||
# Licensed under the MIT License
|
||||
|
||||
"""A module containing entity_extract methods."""
|
||||
"""A module containing extract_graph method."""
|
||||
|
||||
import logging
|
||||
|
||||
@ -35,7 +35,7 @@ async def extract_graph(
|
||||
nonlocal num_started
|
||||
text = row[text_column]
|
||||
id = row[id_column]
|
||||
result = await run_extract_graph(
|
||||
result = await _run_extract_graph(
|
||||
text=text,
|
||||
source_id=id,
|
||||
entity_types=entity_types,
|
||||
@ -68,7 +68,7 @@ async def extract_graph(
|
||||
return (entities, relationships)
|
||||
|
||||
|
||||
async def run_extract_graph(
|
||||
async def _run_extract_graph(
|
||||
text: str,
|
||||
source_id: str,
|
||||
entity_types: list[str],
|
||||
|
||||
@ -1,7 +1,7 @@
|
||||
# Copyright (c) 2024 Microsoft Corporation.
|
||||
# Licensed under the MIT License
|
||||
|
||||
"""A module containing create_graph definition."""
|
||||
"""A module containing graph_to_dataframes method definition."""
|
||||
|
||||
import networkx as nx
|
||||
import pandas as pd
|
||||
|
||||
@ -1,6 +1,7 @@
|
||||
# Copyright (c) 2024 Microsoft Corporation.
|
||||
# Licensed under the MIT License
|
||||
"""A module containing the build_mixed_context method definition."""
|
||||
|
||||
"""A module containing build_mixed_context method definition."""
|
||||
|
||||
import pandas as pd
|
||||
|
||||
|
||||
@ -1,7 +1,7 @@
|
||||
# Copyright (c) 2024 Microsoft Corporation.
|
||||
# Licensed under the MIT License
|
||||
|
||||
"""A module containing create_community_reports and load_strategy methods definition."""
|
||||
"""A module containing summarize_communities method definition."""
|
||||
|
||||
import logging
|
||||
from collections.abc import Callable
|
||||
|
||||
@ -1,7 +1,7 @@
|
||||
# Copyright (c) 2024 Microsoft Corporation.
|
||||
# Licensed under the MIT License
|
||||
|
||||
"""A module containing 'GraphExtractionResult' and 'GraphExtractor' models."""
|
||||
"""A module containing 'SummarizationResult' and 'SummarizeExtractor' models."""
|
||||
|
||||
import json
|
||||
from dataclasses import dataclass
|
||||
|
||||
@ -1,7 +1,7 @@
|
||||
# Copyright (c) 2024 Microsoft Corporation.
|
||||
# Licensed under the MIT License
|
||||
|
||||
"""A module containing the 'Tokenizer', 'TextSplitter', 'NoopTextSplitter' and 'TokenTextSplitter' models."""
|
||||
"""A module containing 'TokenChunkerOptions', 'TextSplitter', 'NoopTextSplitter', 'TokenTextSplitter', 'split_single_text_on_tokens', and 'split_multiple_texts_on_tokens'."""
|
||||
|
||||
import logging
|
||||
from abc import ABC, abstractmethod
|
||||
|
||||
@ -1,7 +1,7 @@
|
||||
# Copyright (c) 2024 Microsoft Corporation.
|
||||
# Licensed under the MIT License
|
||||
|
||||
"""Apply a generic transform function to each row in a table."""
|
||||
"""A module containing derive_from_rows, derive_from_rows_asyncio_threads, and derive_from_rows_asyncio methods."""
|
||||
|
||||
import asyncio
|
||||
import inspect
|
||||
@ -55,9 +55,6 @@ async def derive_from_rows(
|
||||
raise ValueError(msg)
|
||||
|
||||
|
||||
"""A module containing the derive_from_rows_async method."""
|
||||
|
||||
|
||||
async def derive_from_rows_asyncio_threads(
|
||||
input: pd.DataFrame,
|
||||
transform: Callable[[pd.Series], Awaitable[ItemType]],
|
||||
@ -88,9 +85,6 @@ async def derive_from_rows_asyncio_threads(
|
||||
)
|
||||
|
||||
|
||||
"""A module containing the derive_from_rows_async method."""
|
||||
|
||||
|
||||
async def derive_from_rows_asyncio(
|
||||
input: pd.DataFrame,
|
||||
transform: Callable[[pd.Series], Awaitable[ItemType]],
|
||||
|
||||
@ -1,7 +1,7 @@
|
||||
# Copyright (c) 2025 Microsoft Corporation.
|
||||
# Licensed under the MIT License
|
||||
|
||||
"""A module containing mock model provider definitions."""
|
||||
"""A module containing 'MockChatLLM' and 'MockEmbeddingLLM' models."""
|
||||
|
||||
from collections.abc import AsyncGenerator, Generator
|
||||
from typing import Any
|
||||
|
||||
@ -10,9 +10,9 @@ import pytest
|
||||
from graphrag.config.enums import ChunkStrategyType
|
||||
from graphrag.index.operations.chunk_text.chunk_text import (
|
||||
_get_num_total,
|
||||
_load_strategy,
|
||||
_run_strategy,
|
||||
chunk_text,
|
||||
load_strategy,
|
||||
run_strategy,
|
||||
)
|
||||
from graphrag.index.operations.chunk_text.typing import (
|
||||
TextChunk,
|
||||
@ -36,7 +36,7 @@ def test_get_num_total_array():
|
||||
def test_load_strategy_tokens():
|
||||
strategy_type = ChunkStrategyType.tokens
|
||||
|
||||
strategy_loaded = load_strategy(strategy_type)
|
||||
strategy_loaded = _load_strategy(strategy_type)
|
||||
|
||||
assert strategy_loaded.__name__ == "run_tokens"
|
||||
|
||||
@ -44,7 +44,7 @@ def test_load_strategy_tokens():
|
||||
def test_load_strategy_sentence():
|
||||
strategy_type = ChunkStrategyType.sentence
|
||||
|
||||
strategy_loaded = load_strategy(strategy_type)
|
||||
strategy_loaded = _load_strategy(strategy_type)
|
||||
|
||||
assert strategy_loaded.__name__ == "run_sentences"
|
||||
|
||||
@ -55,7 +55,7 @@ def test_load_strategy_none():
|
||||
with pytest.raises(
|
||||
ValueError, match="Unknown strategy: <enum 'ChunkStrategyType'>"
|
||||
):
|
||||
load_strategy(strategy_type) # type: ignore
|
||||
_load_strategy(strategy_type) # type: ignore
|
||||
|
||||
|
||||
def test_run_strategy_str():
|
||||
@ -71,7 +71,7 @@ def test_run_strategy_str():
|
||||
)
|
||||
]
|
||||
|
||||
runned = run_strategy(strategy_mocked, input, config, tick)
|
||||
runned = _run_strategy(strategy_mocked, input, config, tick)
|
||||
assert runned == ["text test for run strategy"]
|
||||
|
||||
|
||||
@ -93,7 +93,7 @@ def test_run_strategy_arr_str():
|
||||
"use for strategy",
|
||||
]
|
||||
|
||||
runned = run_strategy(strategy_mocked, input, config, tick)
|
||||
runned = _run_strategy(strategy_mocked, input, config, tick)
|
||||
assert runned == expected
|
||||
|
||||
|
||||
@ -123,7 +123,7 @@ def test_run_strategy_arr_tuple():
|
||||
),
|
||||
]
|
||||
|
||||
runned = run_strategy(strategy_mocked, input, config, tick)
|
||||
runned = _run_strategy(strategy_mocked, input, config, tick)
|
||||
assert runned == expected
|
||||
|
||||
|
||||
@ -153,12 +153,12 @@ def test_run_strategy_arr_tuple_same_doc():
|
||||
),
|
||||
]
|
||||
|
||||
runned = run_strategy(strategy_mocked, input, config, tick)
|
||||
runned = _run_strategy(strategy_mocked, input, config, tick)
|
||||
assert runned == expected
|
||||
|
||||
|
||||
@mock.patch("graphrag.index.operations.chunk_text.chunk_text.load_strategy")
|
||||
@mock.patch("graphrag.index.operations.chunk_text.chunk_text.run_strategy")
|
||||
@mock.patch("graphrag.index.operations.chunk_text.chunk_text._load_strategy")
|
||||
@mock.patch("graphrag.index.operations.chunk_text.chunk_text._run_strategy")
|
||||
@mock.patch("graphrag.index.operations.chunk_text.chunk_text.progress_ticker")
|
||||
def test_chunk_text(mock_progress_ticker, mock_run_strategy, mock_load_strategy):
|
||||
input_data = pd.DataFrame({"name": ["The Shining"]})
|
||||
|
||||
@ -2,7 +2,7 @@
|
||||
# Licensed under the MIT License
|
||||
import unittest
|
||||
|
||||
from graphrag.index.operations.extract_graph.extract_graph import run_extract_graph
|
||||
from graphrag.index.operations.extract_graph.extract_graph import _run_extract_graph
|
||||
from graphrag.prompts.index.extract_graph import GRAPH_EXTRACTION_PROMPT
|
||||
|
||||
from tests.unit.indexing.verbs.helpers.mock_llm import create_mock_llm
|
||||
@ -22,7 +22,7 @@ SIMPLE_EXTRACTION_RESPONSE = """
|
||||
|
||||
class TestRunChain(unittest.IsolatedAsyncioTestCase):
|
||||
async def test_run_extract_graph_single_document_correct_entities_returned(self):
|
||||
entities_df, _ = await run_extract_graph(
|
||||
entities_df, _ = await _run_extract_graph(
|
||||
text="test_text",
|
||||
source_id="1",
|
||||
entity_types=["person"],
|
||||
@ -39,7 +39,7 @@ class TestRunChain(unittest.IsolatedAsyncioTestCase):
|
||||
)
|
||||
|
||||
async def test_run_extract_graph_single_document_correct_edges_returned(self):
|
||||
_, relationships_df = await run_extract_graph(
|
||||
_, relationships_df = await _run_extract_graph(
|
||||
text="test_text",
|
||||
source_id="1",
|
||||
entity_types=["person"],
|
||||
@ -61,7 +61,7 @@ class TestRunChain(unittest.IsolatedAsyncioTestCase):
|
||||
}
|
||||
|
||||
async def test_run_extract_graph_single_document_source_ids_mapped(self):
|
||||
entities_df, relationships_df = await run_extract_graph(
|
||||
entities_df, relationships_df = await _run_extract_graph(
|
||||
text="test_text",
|
||||
source_id="1",
|
||||
entity_types=["person"],
|
||||
|
||||
Loading…
Reference in New Issue
Block a user