mirror of
https://github.com/microsoft/graphrag.git
synced 2026-01-14 00:57:23 +08:00
Merge from main
This commit is contained in:
parent
bdc170930a
commit
a399dde97b
13  .github/ISSUE_TEMPLATE/bug_report.yml  (vendored)
@@ -4,6 +4,14 @@ title: "[Bug]: <title>"
 labels: ["bug", "triage"]
 
 body:
+  - type: checkboxes
+    id: existingcheck
+    attributes:
+      label: Is there an existing issue for this?
+      description: Please search to see if an issue already exists for the bug you encountered.
+      options:
+        - label: I have searched the existing issues
+        - label: I have checked [#657](https://github.com/microsoft/graphrag/issues/657) to validate if my issue is covered by community support
   - type: textarea
     id: description
     attributes:
@@ -34,6 +42,11 @@ body:
      label: GraphRAG Config Used
      description: The GraphRAG configuration used for the run.
      placeholder: The settings.yaml content or GraphRAG configuration
+      value: |
+        ```yaml
+        # Paste your config here
+
+        ```
  - type: textarea
    id: screenshotslogs
    attributes:
13  .github/ISSUE_TEMPLATE/general_issue.yml  (vendored)
@@ -4,6 +4,14 @@ title: "[Issue]: <title> "
 labels: ["triage"]
 
 body:
+  - type: checkboxes
+    id: existingcheck
+    attributes:
+      label: Is there an existing issue for this?
+      description: Please search to see if an issue already exists for the bug you encountered.
+      options:
+        - label: I have searched the existing issues
+        - label: I have checked [#657](https://github.com/microsoft/graphrag/issues/657) to validate if my issue is covered by community support
   - type: textarea
     id: description
     attributes:
@@ -28,6 +36,11 @@ body:
      label: GraphRAG Config Used
      description: The GraphRAG configuration used for the run.
      placeholder: The settings.yaml content or GraphRAG configuration
+      value: |
+        ```yaml
+        # Paste your config here
+
+        ```
  - type: textarea
    id: screenshotslogs
    attributes:
24  .github/workflows/issues-autoresolve.yml  (vendored, new file)
@@ -0,0 +1,24 @@
+name: Close inactive issues
+on:
+  schedule:
+    - cron: "30 1 * * *"
+
+jobs:
+  close-issues:
+    runs-on: ubuntu-latest
+    permissions:
+      issues: write
+      pull-requests: write
+    steps:
+      - uses: actions/stale@v5
+        with:
+          days-before-issue-stale: 7
+          days-before-issue-close: 5
+          stale-issue-label: "stale"
+          close-issue-label: "autoresolved"
+          stale-issue-message: "This issue has been marked stale due to inactivity after repo maintainer or community member responses that request more information or suggest a solution. It will be closed after five additional days."
+          close-issue-message: "This issue has been closed after being marked as stale for five days. Please reopen if needed."
+          exempt-issue-label: "triage"
+          days-before-pr-stale: -1
+          days-before-pr-close: -1
+          repo-token: ${{ secrets.GITHUB_TOKEN }}
7  .github/workflows/python-ci.yml  (vendored)
@@ -21,7 +21,7 @@ jobs:
   python-ci:
     strategy:
       matrix:
-        python-version: ["3.10", "3.11", "3.12"]
+        python-version: ["3.10", "3.11"] # add 3.12 once gensim supports it. TODO: watch this issue - https://github.com/piskvorky/gensim/issues/3510
         os: [ubuntu-latest, windows-latest]
     env:
       DEBUG: 1
@@ -79,7 +79,10 @@ jobs:
 
       - name: Install dependencies
         shell: bash
-        run: poetry self add setuptools && poetry run python -m pip install gensim && poetry install
+        run: |
+          poetry self add setuptools wheel
+          poetry run python -m pip install gensim
+          poetry install
 
       - name: Check Semversioner
        run: |
6  .github/workflows/python-publish.yml  (vendored)
@@ -36,10 +36,16 @@ jobs:
         with:
           poetry-version: ${{ env.POETRY_VERSION }}
 
+      - name: Add poetry-dynamic-versioning plugin
+        run: poetry self add "poetry-dynamic-versioning[plugin]"
+
       - name: Install dependencies
         shell: bash
         run: poetry install
 
+      - name: Export Publication Version
+        run: echo "version=`poetry version --short`" >> $GITHUB_OUTPUT
+
       - name: Build Distributable
         shell: bash
         run: poetry build
@@ -0,0 +1,4 @@
+{
+    "type": "minor",
+    "description": "Add content-based KNN for selecting prompt tune few shot examples"
+}
@@ -0,0 +1,4 @@
+{
+    "type": "patch",
+    "description": "fix the organization parameter is ineffective during queries"
+}
@@ -0,0 +1,4 @@
+{
+    "type": "patch",
+    "description": "use binary io processing for all file io operations"
+}
@@ -0,0 +1,4 @@
+{
+    "type": "patch",
+    "description": "remove duplicate file read"
+}
@@ -3,4 +3,4 @@
 # @global-owner1 and @global-owner2 will be requested for
 # review when someone opens a pull request.
-* @microsoft/societal-resilience
+* @microsoft/graphrag-core-team
 
@@ -44,7 +44,7 @@ GraphRAG builds upon our prior [research](https://www.microsoft.com/en-us/workla
 
 ### Index
 
-- Slice up an input corpus into a series of TextUnits, which act as analyzable units for the rest of the process, and provide fine-grained references ino our outputs.
+- Slice up an input corpus into a series of TextUnits, which act as analyzable units for the rest of the process, and provide fine-grained references in our outputs.
 - Extract all entities, relationships, and key claims from the TextUnits using an LLM.
 - Perform a hierarchical clustering of the graph using the [Leiden technique](https://arxiv.org/pdf/1810.08473.pdf). To see this visually, check out Figure 1 above. Each circle is an entity (e.g., a person, place, or organization), with the size representing the degree of the entity, and the color representing its community.
 - Generate summaries of each community and its constituents from the bottom-up. This aids in holistic understanding of the dataset.
@@ -43,7 +43,9 @@ class ClaimExtractionConfig(LLMConfig):
             "type": ExtractClaimsStrategyType.graph_intelligence,
             "llm": self.llm.model_dump(),
             **self.parallelization.model_dump(),
-            "extraction_prompt": (Path(root_dir) / self.prompt).read_text()
+            "extraction_prompt": (Path(root_dir) / self.prompt)
+            .read_bytes()
+            .decode(encoding="utf-8")
             if self.prompt
             else None,
             "claim_description": self.description,
@@ -38,7 +38,9 @@ class CommunityReportsConfig(LLMConfig):
             "type": CreateCommunityReportsStrategyType.graph_intelligence,
             "llm": self.llm.model_dump(),
             **self.parallelization.model_dump(),
-            "extraction_prompt": (Path(root_dir) / self.prompt).read_text()
+            "extraction_prompt": (Path(root_dir) / self.prompt)
+            .read_bytes()
+            .decode(encoding="utf-8")
             if self.prompt
             else None,
             "max_report_length": self.max_length,
@@ -38,7 +38,9 @@ class EntityExtractionConfig(LLMConfig):
             "type": ExtractEntityStrategyType.graph_intelligence,
             "llm": self.llm.model_dump(),
             **self.parallelization.model_dump(),
-            "extraction_prompt": (Path(root_dir) / self.prompt).read_text()
+            "extraction_prompt": (Path(root_dir) / self.prompt)
+            .read_bytes()
+            .decode(encoding="utf-8")
             if self.prompt
             else None,
             "max_gleanings": self.max_gleanings,
@@ -34,7 +34,9 @@ class SummarizeDescriptionsConfig(LLMConfig):
             "type": SummarizeStrategyType.graph_intelligence,
             "llm": self.llm.model_dump(),
             **self.parallelization.model_dump(),
-            "summarize_prompt": (Path(root_dir) / self.prompt).read_text()
+            "summarize_prompt": (Path(root_dir) / self.prompt)
+            .read_bytes()
+            .decode(encoding="utf-8")
             if self.prompt
             else None,
             "max_summary_length": self.max_length,
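All four `resolved_strategy()` hunks above swap `read_text()` for an explicit `read_bytes().decode(encoding="utf-8")`. A minimal sketch of why, using a hypothetical prompt file: `read_text()` with no `encoding` argument falls back to the platform default (for example cp1252 on Windows), whereas decoding the raw bytes pins the prompt files to UTF-8 on every OS.

```python
from pathlib import Path

# Hypothetical prompt file containing a non-ASCII character, written as UTF-8.
prompt_path = Path("entity_extraction_demo.txt")
prompt_path.write_bytes("Extract entities such as people, places, and cafés.".encode("utf-8"))

# Platform-dependent: without encoding=..., read_text() uses the locale default
# (e.g. cp1252 on Windows), so the "é" above may decode incorrectly.
risky = prompt_path.read_text()

# Deterministic: read raw bytes and decode them explicitly, as the hunks above do.
safe = prompt_path.read_bytes().decode(encoding="utf-8", errors="strict")
print(safe)
```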
@@ -185,11 +185,11 @@ def _initialize_project_at(path: str, reporter: ProgressReporter) -> None:
 
     dotenv = root / ".env"
     if not dotenv.exists():
-        with settings_yaml.open("w") as file:
-            file.write(INIT_YAML)
+        with settings_yaml.open("wb") as file:
+            file.write(INIT_YAML.encode(encoding="utf-8", errors="strict"))
 
-        with dotenv.open("w") as file:
-            file.write(INIT_DOTENV)
+        with dotenv.open("wb") as file:
+            file.write(INIT_DOTENV.encode(encoding="utf-8", errors="strict"))
 
     prompts_dir = root / "prompts"
     if not prompts_dir.exists():
@@ -197,23 +197,29 @@ def _initialize_project_at(path: str, reporter: ProgressReporter) -> None:
 
     entity_extraction = prompts_dir / "entity_extraction.txt"
     if not entity_extraction.exists():
-        with entity_extraction.open("w") as file:
-            file.write(GRAPH_EXTRACTION_PROMPT)
+        with entity_extraction.open("wb") as file:
+            file.write(
+                GRAPH_EXTRACTION_PROMPT.encode(encoding="utf-8", errors="strict")
+            )
 
     summarize_descriptions = prompts_dir / "summarize_descriptions.txt"
     if not summarize_descriptions.exists():
-        with summarize_descriptions.open("w") as file:
-            file.write(SUMMARIZE_PROMPT)
+        with summarize_descriptions.open("wb") as file:
+            file.write(SUMMARIZE_PROMPT.encode(encoding="utf-8", errors="strict"))
 
     claim_extraction = prompts_dir / "claim_extraction.txt"
     if not claim_extraction.exists():
-        with claim_extraction.open("w") as file:
-            file.write(CLAIM_EXTRACTION_PROMPT)
+        with claim_extraction.open("wb") as file:
+            file.write(
+                CLAIM_EXTRACTION_PROMPT.encode(encoding="utf-8", errors="strict")
+            )
 
     community_report = prompts_dir / "community_report.txt"
     if not community_report.exists():
-        with community_report.open("w") as file:
-            file.write(COMMUNITY_REPORT_PROMPT)
+        with community_report.open("wb") as file:
+            file.write(
+                COMMUNITY_REPORT_PROMPT.encode(encoding="utf-8", errors="strict")
+            )
 
 
 def _create_default_config(
@@ -267,18 +273,18 @@ def _read_config_parameters(root: str, config: str | None, reporter: ProgressRep
 
     if settings_yaml.exists():
         reporter.success(f"Reading settings from {settings_yaml}")
-        with settings_yaml.open("r") as file:
+        with settings_yaml.open("rb") as file:
             import yaml
 
-            data = yaml.safe_load(file)
+            data = yaml.safe_load(file.read().decode(encoding="utf-8", errors="strict"))
             return create_graphrag_config(data, root)
 
     if settings_json.exists():
         reporter.success(f"Reading settings from {settings_json}")
-        with settings_json.open("r") as file:
+        with settings_json.open("rb") as file:
             import json
 
-            data = json.loads(file.read())
+            data = json.loads(file.read().decode(encoding="utf-8", errors="strict"))
             return create_graphrag_config(data, root)
 
     reporter.success("Reading settings from environment variables")
@@ -26,8 +26,8 @@ def load_pipeline_config(config_or_path: str | PipelineConfig) -> PipelineConfig
         read_dotenv(str(Path(config_or_path).parent))
 
         if config_or_path.endswith(".json"):
-            with Path(config_or_path).open(encoding="utf-8") as f:
-                config = json.load(f)
+            with Path(config_or_path).open("rb") as f:
+                config = json.loads(f.read().decode(encoding="utf-8", errors="strict"))
         elif config_or_path.endswith((".yml", ".yaml")):
             config = _parse_yaml(config_or_path)
         else:
@@ -73,7 +73,7 @@ def _create_include_constructor():
         if filename.endswith((".yml", ".yaml")):
             return _parse_yaml(filename)
 
-        with Path(filename).open(encoding="utf-8") as f:
-            return f.read()
+        with Path(filename).open("rb") as f:
+            return f.read().decode(encoding="utf-8", errors="strict")
 
     return handle_include
@@ -21,8 +21,8 @@ class FileWorkflowCallbacks(NoopWorkflowCallbacks):
     def __init__(self, directory: str):
         """Create a new file-based workflow reporter."""
         Path(directory).mkdir(parents=True, exist_ok=True)
-        self._out_stream = open(  # noqa SIM115
-            Path(directory) / "logs.json", "a", encoding="utf-8"
+        self._out_stream = open(  # noqa: PTH123, SIM115
+            Path(directory) / "logs.json", "a", encoding="utf-8", errors="strict"
         )
 
     def on_error(
@@ -114,7 +114,9 @@ class FilePipelineStorage(PipelineStorage):
         write_type = "wb" if is_bytes else "w"
         encoding = None if is_bytes else encoding or self._encoding
         async with aiofiles.open(
-            join_path(self._root_dir, key), cast(Any, write_type), encoding=encoding
+            join_path(self._root_dir, key),
+            cast(Any, write_type),
+            encoding=encoding,
         ) as f:
             await f.write(value)
 
@@ -6,7 +6,6 @@
 from .dicts import dict_has_keys_with_types
 from .hashing import gen_md5_hash
 from .is_null import is_null
-from .json import clean_up_json
 from .load_graph import load_graph
 from .string import clean_str
 from .tokens import num_tokens_from_string, string_from_tokens
@@ -15,7 +14,6 @@ from .uuid import gen_uuid
 
 __all__ = [
     "clean_str",
-    "clean_up_json",
     "dict_has_keys_with_types",
     "gen_md5_hash",
     "gen_uuid",
@@ -1,27 +0,0 @@
-# Copyright (c) 2024 Microsoft Corporation.
-# Licensed under the MIT License
-
-"""JSON cleaning and formatting utilities."""
-
-
-def clean_up_json(json_str: str):
-    """Clean up json string."""
-    json_str = (
-        json_str.replace("\\n", "")
-        .replace("\n", "")
-        .replace("\r", "")
-        .replace('"[{', "[{")
-        .replace('}]"', "}]")
-        .replace("\\", "")
-        .strip()
-    )
-
-    # Remove JSON Markdown Frame
-    if json_str.startswith("```json"):
-        json_str = json_str[len("```json") :]
-    if json_str.startswith("json"):
-        json_str = json_str[len("json") :]
-    if json_str.endswith("```"):
-        json_str = json_str[: len(json_str) - len("```")]
-
-    return json_str
@@ -1,25 +0,0 @@
-# Copyright (c) 2024 Microsoft Corporation.
-# Licensed under the MIT License
-
-"""JSON cleaning and formatting utilities."""
-
-
-def clean_up_json(json_str: str) -> str:
-    """Clean up json string."""
-    json_str = (
-        json_str.replace("\\n", "")
-        .replace("\n", "")
-        .replace("\r", "")
-        .replace('"[{', "[{")
-        .replace('}]"', "}]")
-        .replace("\\", "")
-        .strip()
-    )
-
-    # Remove JSON Markdown Frame
-    if json_str.startswith("```json"):
-        json_str = json_str[len("```json") :]
-    if json_str.endswith("```"):
-        json_str = json_str[: len(json_str) - len("```")]
-
-    return json_str
@@ -16,7 +16,6 @@ from graphrag.llm.types import (
     LLMOutput,
 )
 
-from ._json import clean_up_json
 from ._prompts import JSON_CHECK_PROMPT
 from .openai_configuration import OpenAIConfiguration
 from .types import OpenAIClientTypes
@@ -104,11 +103,10 @@ class OpenAIChatLLM(BaseLLM[CompletionInput, CompletionOutput]):
             },
         )
 
-        raw_output = result.output or ""
-        json_output = try_parse_json_object(raw_output)
+        output, json_output = try_parse_json_object(result.output or "")
 
         return LLMOutput[CompletionOutput](
-            output=raw_output,
+            output=output,
             json=json_output,
             history=result.history,
         )
@@ -119,24 +117,22 @@ class OpenAIChatLLM(BaseLLM[CompletionInput, CompletionOutput]):
         # Otherwise, clean up the output and try to parse it as json
         result = await self._invoke(input, **kwargs)
         history = result.history or []
-        output = clean_up_json(result.output or "")
-        try:
-            json_output = try_parse_json_object(output)
+        output, json_output = try_parse_json_object(result.output or "")
+        if json_output:
             return LLMOutput[CompletionOutput](
-                output=output, json=json_output, history=history
+                output=result.output, json=json_output, history=history
             )
-        except (TypeError, JSONDecodeError):
-            log.warning("error parsing llm json, retrying")
-            # If cleaned up json is unparsable, use the LLM to reformat it (may throw)
-            result = await self._try_clean_json_with_llm(output, **kwargs)
-            output = clean_up_json(result.output or "")
-            json = try_parse_json_object(output)
+        # if not return correct formatted json, retry
+        log.warning("error parsing llm json, retrying")
+        # If cleaned up json is unparsable, use the LLM to reformat it (may throw)
+        result = await self._try_clean_json_with_llm(output, **kwargs)
+        output, json_output = try_parse_json_object(result.output or "")
 
-            return LLMOutput[CompletionOutput](
-                output=output,
-                json=json,
-                history=history,
-            )
+        return LLMOutput[CompletionOutput](
+            output=output,
+            json=json_output,
+            history=history,
+        )
 
     async def _try_clean_json_with_llm(
         self, output: str, **kwargs: Unpack[LLMInput]
@@ -2,13 +2,14 @@
 # Licensed under the MIT License
 
 """Utility functions for the OpenAI API."""
 
+import re
 import json
 import logging
 from collections.abc import Callable
 from typing import Any
 
 import tiktoken
+from json_repair import repair_json
 from openai import (
     APIConnectionError,
     InternalServerError,
@@ -87,17 +88,52 @@ def get_completion_llm_args(
     }
 
 
-def try_parse_json_object(input: str) -> dict:
-    """Generate JSON-string output using best-attempt prompting & parsing techniques."""
+def try_parse_json_object(input: str) -> tuple[str, dict]:
+    """JSON cleaning and formatting utilities."""
+    """sometime, the llm return a json string with some extra description, this function will clean it up."""
+    _pattern = r"\{(.*)\}"
+    _match = re.search(_pattern, input)
+    input = "{" + _match.group(1) + "}" if _match else input
+
+    """Clean up json string."""
+    input = (
+        input.replace("{{","{")
+        .replace("}}","}")
+        .replace('"[{', "[{")
+        .replace('}]"', "}]")
+        .replace("\\", " ")
+        .replace("\\n", " ")
+        .replace("\n", " ")
+        .replace("\r", "")
+        .strip()
+    )
+
+    # Remove JSON Markdown Frame
+    if input.startswith("```json"):
+        input = input[len("```json") :]
+    if input.endswith("```"):
+        input = input[: len(input) - len("```")]
+
     try:
         result = json.loads(input)
     except json.JSONDecodeError:
         log.exception("error loading json, json=%s", input)
-        raise
+
+        """Fixup potentially malformed json string using json_repair."""
+        input = str(repair_json(json_str=input, return_objects=False))
+
+        """Generate JSON-string output using best-attempt prompting & parsing techniques."""
+        try:
+            result = json.loads(input)
+        except json.JSONDecodeError:
+            log.exception("error loading json, json=%s", input)
+            return input, {}
+        else:
+            if not isinstance(result, dict):
+                log.exception("not expected dict type. type=%s:", type(result))
+                return input, {}
+            return input, result
-    else:
-        if not isinstance(result, dict):
-            raise TypeError
-        return result
+    return input, result
 
 
 def get_sleep_time_from_error(e: Any) -> float:
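As reworked above, `try_parse_json_object` now returns a `(text, parsed)` tuple and signals failure with an empty dict instead of raising, falling back to `json_repair` when the first `json.loads` fails. A rough usage sketch; the LLM reply string here is invented:

```python
from graphrag.llm.openai.utils import try_parse_json_object

# Invented LLM reply: prose wrapped around a fenced JSON object on one line.
raw = 'Sure, here you go:\n```json\n{"points": [{"answer": "A", "score": 80}]}\n```'

text, parsed = try_parse_json_object(raw)

if parsed:  # non-empty dict: the JSON was recovered (possibly after repair)
    print(parsed["points"][0]["score"])
else:       # empty dict: even json_repair could not salvage the output
    print("unparsable, raw text was:", text)
```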
@@ -10,7 +10,7 @@ from enum import Enum
 from graphrag.prompt_tune.generator import MAX_TOKEN_COUNT
 from graphrag.prompt_tune.loader import MIN_CHUNK_SIZE
 
-from .cli import fine_tune
+from .cli import prompt_tune
 
 
 class DocSelectionType(Enum):
@@ -19,6 +19,7 @@ class DocSelectionType(Enum):
     ALL = "all"
     RANDOM = "random"
     TOP = "top"
+    AUTO = "auto"
 
     def __str__(self):
         """Return the string representation of the enum value."""
@@ -46,13 +47,29 @@ if __name__ == "__main__":
 
     parser.add_argument(
         "--method",
-        help="The method to select documents, one of: all, random or top",
+        help="The method to select documents, one of: all, random, top or auto",
         required=False,
         type=DocSelectionType,
         choices=list(DocSelectionType),
         default=DocSelectionType.RANDOM,
     )
 
+    parser.add_argument(
+        "--n_subset_max",
+        help="The number of text chunks to embed when using auto selection method",
+        required=False,
+        type=int,
+        default=300,
+    )
+
+    parser.add_argument(
+        "--k",
+        help="The maximum number of documents to select from each centroid when using auto selection method",
+        required=False,
+        type=int,
+        default=15,
+    )
+
     parser.add_argument(
         "--limit",
         help="The limit of files to load when doing random or top selection",
@@ -69,6 +86,14 @@
         default=MAX_TOKEN_COUNT,
     )
 
+    parser.add_argument(
+        "--min-examples-required",
+        help="The minimum number of examples required in entity extraction prompt",
+        type=int,
+        required=False,
+        default=2,
+    )
+
     parser.add_argument(
         "--chunk-size",
         help="Max token count for prompt generation",
@@ -106,7 +131,7 @@
     loop = asyncio.get_event_loop()
 
     loop.run_until_complete(
-        fine_tune(
+        prompt_tune(
             args.root,
             args.domain,
             str(args.method),
@@ -116,5 +141,8 @@
             args.language,
             args.no_entity_types,
             args.output,
+            args.n_subset_max,
+            args.k,
+            args.min_examples_required,
         )
     )
@@ -32,7 +32,7 @@ from graphrag.prompt_tune.loader import (
 )
 
 
-async def fine_tune(
+async def prompt_tune(
     root: str,
     domain: str,
     select: str = "random",
@@ -42,8 +42,11 @@
     language: str | None = None,
     skip_entity_types: bool = False,
     output: str = "prompts",
+    n_subset_max: int = 300,
+    k: int = 15,
+    min_examples_required: int = 2,
 ):
-    """Fine tune the model.
+    """Prompt tune the model.
 
     Parameters
     ----------
@@ -55,11 +58,13 @@
     - chunk_size: The chunk token size to use.
     - skip_entity_types: Skip generating entity types.
     - output: The output folder to store the prompts.
+    - n_subset_max: The number of text chunks to embed when using auto selection method.
+    - k: The number of documents to select when using auto selection method.
     """
     reporter = PrintProgressReporter("")
     config = read_config_parameters(root, reporter)
 
-    await fine_tune_with_config(
+    await prompt_tune_with_config(
         root,
         config,
         domain,
@@ -71,10 +76,13 @@
         skip_entity_types,
         output,
         reporter,
+        n_subset_max,
+        k,
+        min_examples_required,
     )
 
 
-async def fine_tune_with_config(
+async def prompt_tune_with_config(
     root: str,
     config: GraphRagConfig,
     domain: str,
@@ -86,8 +94,11 @@
     skip_entity_types: bool = False,
     output: str = "prompts",
     reporter: ProgressReporter | None = None,
+    n_subset_max: int = 300,
+    k: int = 15,
+    min_examples_required: int = 2,
 ):
-    """Fine tune the model with a configuration.
+    """Prompt tune the model with a configuration.
 
     Parameters
     ----------
@@ -101,6 +112,8 @@
     - skip_entity_types: Skip generating entity types.
     - output: The output folder to store the prompts.
     - reporter: The progress reporter.
+    - n_subset_max: The number of text chunks to embed when using auto selection method.
+    - k: The number of documents to select when using auto selection method.
 
     Returns
     -------
@@ -118,11 +131,13 @@
         select_method=select,
         reporter=reporter,
         chunk_size=chunk_size,
+        n_subset_max=n_subset_max,
+        k=k,
     )
 
     # Create LLM from config
     llm = load_llm(
-        "fine_tuning",
+        "prompt_tuning",
         config.llm.type,
         NoopVerbCallbacks(),
         None,
@@ -139,6 +154,7 @@
         language,
         max_tokens,
         skip_entity_types,
+        min_examples_required,
     )
 
 
@@ -152,6 +168,7 @@ async def generate_indexing_prompts(
     language: str | None = None,
     max_tokens: int = MAX_TOKEN_COUNT,
     skip_entity_types: bool = False,
+    min_examples_required: int = 2,
 ):
     """Generate indexing prompts.
 
@@ -165,6 +182,7 @@
     - domain: The domain to map the input documents to.
     - max_tokens: The maximum number of tokens to use on entity extraction prompts
     - skip_entity_types: Skip generating entity types.
+    - min_examples_required: The minimum number of examples required for entity extraction prompts.
     """
     if not domain:
         reporter.info("Generating domain...")
@@ -221,6 +239,7 @@
         output_path=output_path,
         encoding_model=config.encoding_model,
         max_token_count=max_tokens,
+        min_examples_required=min_examples_required,
     )
     reporter.info(f"Generated entity extraction prompt, stored in folder {output_path}")
 
@@ -42,7 +42,7 @@ def create_community_summarization_prompt(
 
         output_path = output_path / COMMUNITY_SUMMARIZATION_FILENAME
         # Write file to output path
-        with output_path.open("w") as file:
-            file.write(prompt)
+        with output_path.open("wb") as file:
+            file.write(prompt.encode(encoding="utf-8", errors="strict"))
 
     return prompt
@@ -27,6 +27,7 @@ def create_entity_extraction_prompt(
     encoding_model: str = defs.ENCODING_MODEL,
     json_mode: bool = False,
     output_path: Path | None = None,
+    min_examples_required: int = 2,
 ) -> str:
     """
     Create a prompt for entity extraction.
@@ -41,6 +42,7 @@
     - max_token_count (int): The maximum number of tokens to use for the prompt
     - json_mode (bool): Whether to use JSON mode for the prompt. Default is False
     - output_path (Path | None): The path to write the prompt to. Default is None. If None, the prompt is not written to a file. Default is None.
+    - min_examples_required (int): The minimum number of examples required. Default is 2.
 
     Returns
     -------
@@ -79,8 +81,8 @@
 
         example_tokens = num_tokens_from_string(example_formatted, model=encoding_model)
 
-        # Squeeze in at least one example
-        if i > 0 and example_tokens > tokens_left:
+        # Ensure at least three examples are included
+        if i >= min_examples_required and example_tokens > tokens_left:
             break
 
         examples_prompt += example_formatted
@@ -99,7 +101,7 @@
 
         output_path = output_path / ENTITY_EXTRACTION_FILENAME
         # Write file to output path
-        with output_path.open("w") as file:
-            file.write(prompt)
+        with output_path.open("wb") as file:
+            file.write(prompt.encode(encoding="utf-8", errors="strict"))
 
     return prompt
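The `min_examples_required` parameter above relaxes the old "squeeze in at least one example" rule: an example is now appended even when it would exceed the token budget, as long as fewer than `min_examples_required` examples have been included so far. A standalone sketch of that selection loop, with invented examples and a stubbed token counter standing in for `num_tokens_from_string`:

```python
def count_tokens(text: str) -> int:
    """Stand-in token counter: roughly one token per four characters."""
    return max(1, len(text) // 4)


def select_examples(examples: list[str], token_budget: int, min_examples_required: int = 2) -> str:
    """Append examples until the budget is spent, but always keep the required minimum."""
    prompt, tokens_left = "", token_budget
    for i, example in enumerate(examples):
        example_tokens = count_tokens(example)
        # Only stop for budget reasons once the minimum number of examples is in.
        if i >= min_examples_required and example_tokens > tokens_left:
            break
        prompt += example
        tokens_left -= example_tokens
    return prompt


# The oversized second example is still included because the minimum is 2.
few_shot = ["Example 1: short\n", "Example 2: " + "x" * 400 + "\n", "Example 3: short\n"]
print(count_tokens(select_examples(few_shot, token_budget=10)))
```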
@@ -30,7 +30,7 @@ def create_entity_summarization_prompt(
 
         output_path = output_path / ENTITY_SUMMARIZATION_FILENAME
         # Write file to output path
-        with output_path.open("w") as file:
-            file.write(prompt)
+        with output_path.open("wb") as file:
+            file.write(prompt.encode(encoding="utf-8", errors="strict"))
 
     return prompt
@@ -25,18 +25,18 @@ def read_config_parameters(root: str, reporter: ProgressReporter):
 
     if settings_yaml.exists():
         reporter.info(f"Reading settings from {settings_yaml}")
-        with settings_yaml.open("r") as file:
+        with settings_yaml.open("rb") as file:
             import yaml
 
-            data = yaml.safe_load(file)
+            data = yaml.safe_load(file.read().decode(encoding="utf-8", errors="strict"))
             return create_graphrag_config(data, root)
 
     if settings_json.exists():
         reporter.info(f"Reading settings from {settings_json}")
-        with settings_json.open("r") as file:
+        with settings_json.open("rb") as file:
             import json
 
-            data = json.loads(file.read())
+            data = json.loads(file.read().decode(encoding="utf-8", errors="strict"))
             return create_graphrag_config(data, root)
 
     reporter.info("Reading settings from environment variables")
@@ -5,16 +5,45 @@
 
 from typing import cast
 
+import numpy as np
 import pandas as pd
 from datashaper import NoopVerbCallbacks, TableContainer, VerbInput
 
 from graphrag.config.models.graph_rag_config import GraphRagConfig
 from graphrag.index.input import load_input
+from graphrag.index.llm import load_llm_embeddings
 from graphrag.index.progress.types import ProgressReporter
 from graphrag.index.verbs import chunk
+from graphrag.llm.types.llm_types import EmbeddingLLM
 
-MIN_CHUNK_SIZE = 200
 MIN_CHUNK_OVERLAP = 0
+MIN_CHUNK_SIZE = 200
+N_SUBSET_MAX = 300
+K = 15
+
+
+async def _embed_chunks(
+    text_chunks: pd.DataFrame,
+    embedding_llm: EmbeddingLLM,
+    n_subset_max: int = N_SUBSET_MAX,
+) -> tuple[pd.DataFrame, np.ndarray]:
+    """Convert text chunks into dense text embeddings."""
+    sampled_text_chunks = text_chunks.sample(n=min(n_subset_max, len(text_chunks)))
+    embeddings = await embedding_llm(sampled_text_chunks["chunks"].tolist())
+    return text_chunks, np.array(embeddings.output)
+
+
+def _sample_chunks_from_embeddings(
+    text_chunks: pd.DataFrame,
+    embeddings,
+    k: int = K,
+) -> pd.DataFrame:
+    """Sample text chunks from embeddings."""
+    center = np.mean(embeddings, axis=0)
+    distances = np.linalg.norm(embeddings - center, axis=1)
+    nearest_indices = np.argsort(distances)[:k]
+
+    return text_chunks.iloc[nearest_indices]
 
 
 async def load_docs_in_chunks(
@@ -24,6 +53,8 @@ async def load_docs_in_chunks(
     limit: int,
     reporter: ProgressReporter,
     chunk_size: int = MIN_CHUNK_SIZE,
+    n_subset_max: int = N_SUBSET_MAX,
+    k: int = K,
 ) -> list[str]:
     """Load docs into chunks for generating prompts."""
     dataset = await load_input(config.input, reporter, root)
@@ -57,6 +88,22 @@
         chunks_df = chunks_df[:limit]
     elif select_method == "random":
         chunks_df = chunks_df.sample(n=limit)
+    elif select_method == "auto":
+        if k is None or k <= 0:
+            msg = "k must be an integer > 0"
+            raise ValueError(msg)
+        embedding_llm = load_llm_embeddings(
+            name="prompt_tuning_embeddings",
+            llm_type=config.embeddings.resolved_strategy()["llm"]["type"],
+            callbacks=NoopVerbCallbacks(),
+            cache=None,
+            llm_config=config.embeddings.resolved_strategy()["llm"],
+        )
+
+        chunks_df, embeddings = await _embed_chunks(
+            chunks_df, embedding_llm, n_subset_max=n_subset_max
+        )
+        chunks_df = _sample_chunks_from_embeddings(chunks_df, embeddings, k=k)
 
     # Convert the dataset to list form, so we have a list of documents
     return chunks_df["chunks"].tolist()
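The new `auto` selection path above embeds a sampled subset of chunks (capped at `n_subset_max`) and keeps the `k` chunks whose embeddings lie closest to the centroid of that subset. A NumPy-only sketch of the selection step, with random vectors standing in for real embeddings:

```python
import numpy as np
import pandas as pd

rng = np.random.default_rng(42)

# Stand-ins: 100 text chunks and fake 8-dimensional embeddings for them.
chunks_df = pd.DataFrame({"chunks": [f"chunk {i}" for i in range(100)]})
embeddings = rng.normal(size=(100, 8))


def nearest_to_centroid(text_chunks: pd.DataFrame, embeddings: np.ndarray, k: int = 15) -> pd.DataFrame:
    """Keep the k chunks nearest the embedding centroid (the idea behind _sample_chunks_from_embeddings)."""
    center = np.mean(embeddings, axis=0)                     # centroid of all embeddings
    distances = np.linalg.norm(embeddings - center, axis=1)  # Euclidean distance to the centroid
    return text_chunks.iloc[np.argsort(distances)[:k]]       # k closest chunks


print(nearest_to_centroid(chunks_df, embeddings, k=5)["chunks"].tolist())
```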
@@ -110,7 +110,6 @@ def run_local_search(
     final_relationships = pd.read_parquet(
         data_path / "create_final_relationships.parquet"
     )
-    final_nodes = pd.read_parquet(data_path / "create_final_nodes.parquet")
     final_entities = pd.read_parquet(data_path / "create_final_entities.parquet")
     final_covariates_path = data_path / "create_final_covariates.parquet"
     final_covariates = (
@@ -194,18 +193,20 @@ def _read_config_parameters(root: str):
 
     if settings_yaml.exists():
         reporter.info(f"Reading settings from {settings_yaml}")
-        with settings_yaml.open("r") as file:
+        with settings_yaml.open(
+            "rb",
+        ) as file:
             import yaml
 
-            data = yaml.safe_load(file)
+            data = yaml.safe_load(file.read().decode(encoding="utf-8", errors="strict"))
             return create_graphrag_config(data, root)
 
     if settings_json.exists():
         reporter.info(f"Reading settings from {settings_json}")
-        with settings_json.open("r") as file:
+        with settings_json.open("rb") as file:
             import json
 
-            data = json.loads(file.read())
+            data = json.loads(file.read().decode(encoding="utf-8", errors="strict"))
             return create_graphrag_config(data, root)
 
     reporter.info("Reading settings from environment variables")
@@ -58,6 +58,7 @@ def get_llm(config: GraphRagConfig) -> ChatOpenAI:
             else None
         ),
         api_base=config.llm.api_base,
+        organization=config.llm.organization,
         model=config.llm.model,
         api_type=OpenaiApiType.AzureOpenAI if is_azure_client else OpenaiApiType.OpenAI,
         deployment_name=config.llm.deployment_name,
@@ -89,6 +90,7 @@ def get_text_embedder(config: GraphRagConfig) -> OpenAIEmbedding:
             else None
         ),
         api_base=config.embeddings.llm.api_base,
+        organization=config.llm.organization,
         api_type=OpenaiApiType.AzureOpenAI if is_azure_client else OpenaiApiType.OpenAI,
         model=config.embeddings.llm.model,
         deployment_name=config.embeddings.llm.deployment_name,
@@ -13,7 +13,7 @@ from typing import Any
 import pandas as pd
 import tiktoken
 
-from graphrag.index.utils.json import clean_up_json
+from graphrag.llm.openai.utils import try_parse_json_object
 from graphrag.query.context_builder.builders import GlobalContextBuilder
 from graphrag.query.context_builder.conversation_history import (
     ConversationHistory,
@@ -32,7 +32,7 @@ from graphrag.query.structured_search.global_search.reduce_system_prompt import
     NO_DATA_ANSWER,
     REDUCE_SYSTEM_PROMPT,
 )
-
+from graphrag.llm.openai.utils import try_parse_json_object
 DEFAULT_MAP_LLM_PARAMS = {
     "max_tokens": 1000,
     "temperature": 0.0,
@@ -188,7 +188,6 @@ class GlobalSearch(BaseSearch):
             processed_response = self.parse_search_response(search_response)
         except ValueError:
             # Clean up and retry parse
-            search_response = clean_up_json(search_response)
             try:
                 # parse search response json
                 processed_response = self.parse_search_response(search_response)
@@ -229,6 +228,10 @@
         list[dict[str, Any]]
             A list of key points, each key point is a dictionary with "answer" and "score" keys
         """
+        search_response,_j = try_parse_json_object(search_response)
+        if _j =={}:
+            return [{"answer":"not avaliable","score": 0}]
+
         parsed_elements = json.loads(search_response)["points"]
         return [
             {
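The `parse_search_response` change above routes the map-step output through `try_parse_json_object` and bails out with a placeholder key point when nothing parseable comes back, instead of letting `json.loads` raise. A hedged sketch of that guard as a free function; the malformed response string is invented:

```python
import json
from typing import Any

from graphrag.llm.openai.utils import try_parse_json_object


def parse_key_points(search_response: str) -> list[dict[str, Any]]:
    """Hypothetical helper mirroring the guard added to GlobalSearch.parse_search_response."""
    search_response, parsed = try_parse_json_object(search_response)
    if parsed == {}:
        # Unparseable map response: degrade to a single zero-score placeholder answer.
        return [{"answer": "not available", "score": 0}]
    return json.loads(search_response)["points"]


# Invented malformed output: the guard returns the placeholder instead of raising.
print(parse_key_points("The model rambled and returned no JSON here."))
```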
1483  poetry.lock  (generated)
File diff suppressed because it is too large.
@@ -85,6 +85,7 @@ typing-extensions = "^4.12.2"
 #Azure
 azure-storage-blob = "^12.19.0"
 azure-identity = "^1.17.1"
+json-repair = "^0.25.3"
 
 [tool.poetry.group.dev.dependencies]
 coverage = "^7.6.0"
@@ -38,8 +38,7 @@ def _load_fixtures():
             continue
 
         config_file = fixtures_path / subfolder / "config.json"
-        with config_file.open() as f:
-            params.append((subfolder, json.load(f)))
+        params.append((subfolder, json.loads(config_file.read_bytes().decode("utf-8"))))
 
     return params
 
@@ -104,8 +103,7 @@ async def prepare_azurite_data(input_path: str, azure: dict) -> Callable[[], Non
     csv_files = list((root / "input").glob("*.csv"))
     data_files = txt_files + csv_files
     for data_file in data_files:
-        with data_file.open(encoding="utf8") as f:
-            text = f.read()
+        text = data_file.read_bytes().decode("utf-8")
         file_path = (
             str(Path(input_base_dir) / data_file.name)
             if input_base_dir
@@ -166,8 +164,7 @@ class TestIndexer:
         assert artifacts.exists(), "artifact folder does not exist"
 
         # Check stats for all workflow
-        with (artifacts / "stats.json").open() as f:
-            stats = json.load(f)
+        stats = json.loads((artifacts / "stats.json").read_bytes().decode("utf-8"))
 
         # Check all workflows run
         expected_workflows = set(workflow_config.keys())