Merge from main

Alonso Guevara 2024-08-22 10:36:41 -06:00
parent bdc170930a
commit a399dde97b
37 changed files with 1137 additions and 811 deletions

View File

@ -4,6 +4,14 @@ title: "[Bug]: <title>"
labels: ["bug", "triage"]
body:
- type: checkboxes
id: existingcheck
attributes:
label: Is there an existing issue for this?
description: Please search to see if an issue already exists for the bug you encountered.
options:
- label: I have searched the existing issues
- label: I have checked [#657](https://github.com/microsoft/graphrag/issues/657) to validate if my issue is covered by community support
- type: textarea
id: description
attributes:
@ -34,6 +42,11 @@ body:
label: GraphRAG Config Used
description: The GraphRAG configuration used for the run.
placeholder: The settings.yaml content or GraphRAG configuration
value: |
```yaml
# Paste your config here
```
- type: textarea
id: screenshotslogs
attributes:

View File

@ -4,6 +4,14 @@ title: "[Issue]: <title> "
labels: ["triage"]
body:
- type: checkboxes
id: existingcheck
attributes:
label: Is there an existing issue for this?
description: Please search to see if an issue already exists for the bug you encountered.
options:
- label: I have searched the existing issues
- label: I have checked [#657](https://github.com/microsoft/graphrag/issues/657) to validate if my issue is covered by community support
- type: textarea
id: description
attributes:
@ -28,6 +36,11 @@ body:
label: GraphRAG Config Used
description: The GraphRAG configuration used for the run.
placeholder: The settings.yaml content or GraphRAG configuration
value: |
```yaml
# Paste your config here
```
- type: textarea
id: screenshotslogs
attributes:

View File

@ -0,0 +1,24 @@
name: Close inactive issues
on:
schedule:
- cron: "30 1 * * *"
jobs:
close-issues:
runs-on: ubuntu-latest
permissions:
issues: write
pull-requests: write
steps:
- uses: actions/stale@v5
with:
days-before-issue-stale: 7
days-before-issue-close: 5
stale-issue-label: "stale"
close-issue-label: "autoresolved"
stale-issue-message: "This issue has been marked stale due to inactivity after repo maintainer or community member responses that request more information or suggest a solution. It will be closed after five additional days."
close-issue-message: "This issue has been closed after being marked as stale for five days. Please reopen if needed."
exempt-issue-label: "triage"
days-before-pr-stale: -1
days-before-pr-close: -1
repo-token: ${{ secrets.GITHUB_TOKEN }}

View File

@ -21,7 +21,7 @@ jobs:
python-ci:
strategy:
matrix:
python-version: ["3.10", "3.11", "3.12"]
python-version: ["3.10", "3.11"] # add 3.12 once gensim supports it. TODO: watch this issue - https://github.com/piskvorky/gensim/issues/3510
os: [ubuntu-latest, windows-latest]
env:
DEBUG: 1
@ -79,7 +79,10 @@ jobs:
- name: Install dependencies
shell: bash
run: poetry self add setuptools && poetry run python -m pip install gensim && poetry install
run: |
poetry self add setuptools wheel
poetry run python -m pip install gensim
poetry install
- name: Check Semversioner
run: |

View File

@ -36,10 +36,16 @@ jobs:
with:
poetry-version: ${{ env.POETRY_VERSION }}
- name: Add poetry-dynamic-versioning plugin
run: poetry self add "poetry-dynamic-versioning[plugin]"
- name: Install dependencies
shell: bash
run: poetry install
- name: Export Publication Version
run: echo "version=`poetry version --short`" >> $GITHUB_OUTPUT
- name: Build Distributable
shell: bash
run: poetry build

View File

@ -0,0 +1,4 @@
{
"type": "minor",
"description": "Add content-based KNN for selecting prompt tune few shot examples"
}

View File

@ -0,0 +1,4 @@
{
"type": "patch",
"description": "fix the organization parameter is ineffective during queries"
}

View File

@ -0,0 +1,4 @@
{
"type": "patch",
"description": "use binary io processing for all file io operations"
}

View File

@ -0,0 +1,4 @@
{
"type": "patch",
"description": "remove duplicate file read"
}

View File

@ -3,4 +3,4 @@
# @global-owner1 and @global-owner2 will be requested for
# review when someone opens a pull request.
* @microsoft/societal-resilience
* @microsoft/graphrag-core-team

View File

@ -44,7 +44,7 @@ GraphRAG builds upon our prior [research](https://www.microsoft.com/en-us/workla
### Index
- Slice up an input corpus into a series of TextUnits, which act as analyzable units for the rest of the process, and provide fine-grained references ino our outputs.
- Slice up an input corpus into a series of TextUnits, which act as analyzable units for the rest of the process, and provide fine-grained references in our outputs.
- Extract all entities, relationships, and key claims from the TextUnits using an LLM.
- Perform a hierarchical clustering of the graph using the [Leiden technique](https://arxiv.org/pdf/1810.08473.pdf). To see this visually, check out Figure 1 above. Each circle is an entity (e.g., a person, place, or organization), with the size representing the degree of the entity, and the color representing its community.
- Generate summaries of each community and its constituents from the bottom-up. This aids in holistic understanding of the dataset.

View File

@ -43,7 +43,9 @@ class ClaimExtractionConfig(LLMConfig):
"type": ExtractClaimsStrategyType.graph_intelligence,
"llm": self.llm.model_dump(),
**self.parallelization.model_dump(),
"extraction_prompt": (Path(root_dir) / self.prompt).read_text()
"extraction_prompt": (Path(root_dir) / self.prompt)
.read_bytes()
.decode(encoding="utf-8")
if self.prompt
else None,
"claim_description": self.description,

View File

@ -38,7 +38,9 @@ class CommunityReportsConfig(LLMConfig):
"type": CreateCommunityReportsStrategyType.graph_intelligence,
"llm": self.llm.model_dump(),
**self.parallelization.model_dump(),
"extraction_prompt": (Path(root_dir) / self.prompt).read_text()
"extraction_prompt": (Path(root_dir) / self.prompt)
.read_bytes()
.decode(encoding="utf-8")
if self.prompt
else None,
"max_report_length": self.max_length,

View File

@ -38,7 +38,9 @@ class EntityExtractionConfig(LLMConfig):
"type": ExtractEntityStrategyType.graph_intelligence,
"llm": self.llm.model_dump(),
**self.parallelization.model_dump(),
"extraction_prompt": (Path(root_dir) / self.prompt).read_text()
"extraction_prompt": (Path(root_dir) / self.prompt)
.read_bytes()
.decode(encoding="utf-8")
if self.prompt
else None,
"max_gleanings": self.max_gleanings,

View File

@ -34,7 +34,9 @@ class SummarizeDescriptionsConfig(LLMConfig):
"type": SummarizeStrategyType.graph_intelligence,
"llm": self.llm.model_dump(),
**self.parallelization.model_dump(),
"summarize_prompt": (Path(root_dir) / self.prompt).read_text()
"summarize_prompt": (Path(root_dir) / self.prompt)
.read_bytes()
.decode(encoding="utf-8")
if self.prompt
else None,
"max_summary_length": self.max_length,

View File

@ -185,11 +185,11 @@ def _initialize_project_at(path: str, reporter: ProgressReporter) -> None:
dotenv = root / ".env"
if not dotenv.exists():
with settings_yaml.open("w") as file:
file.write(INIT_YAML)
with settings_yaml.open("wb") as file:
file.write(INIT_YAML.encode(encoding="utf-8", errors="strict"))
with dotenv.open("w") as file:
file.write(INIT_DOTENV)
with dotenv.open("wb") as file:
file.write(INIT_DOTENV.encode(encoding="utf-8", errors="strict"))
prompts_dir = root / "prompts"
if not prompts_dir.exists():
@ -197,23 +197,29 @@ def _initialize_project_at(path: str, reporter: ProgressReporter) -> None:
entity_extraction = prompts_dir / "entity_extraction.txt"
if not entity_extraction.exists():
with entity_extraction.open("w") as file:
file.write(GRAPH_EXTRACTION_PROMPT)
with entity_extraction.open("wb") as file:
file.write(
GRAPH_EXTRACTION_PROMPT.encode(encoding="utf-8", errors="strict")
)
summarize_descriptions = prompts_dir / "summarize_descriptions.txt"
if not summarize_descriptions.exists():
with summarize_descriptions.open("w") as file:
file.write(SUMMARIZE_PROMPT)
with summarize_descriptions.open("wb") as file:
file.write(SUMMARIZE_PROMPT.encode(encoding="utf-8", errors="strict"))
claim_extraction = prompts_dir / "claim_extraction.txt"
if not claim_extraction.exists():
with claim_extraction.open("w") as file:
file.write(CLAIM_EXTRACTION_PROMPT)
with claim_extraction.open("wb") as file:
file.write(
CLAIM_EXTRACTION_PROMPT.encode(encoding="utf-8", errors="strict")
)
community_report = prompts_dir / "community_report.txt"
if not community_report.exists():
with community_report.open("w") as file:
file.write(COMMUNITY_REPORT_PROMPT)
with community_report.open("wb") as file:
file.write(
COMMUNITY_REPORT_PROMPT.encode(encoding="utf-8", errors="strict")
)
def _create_default_config(
@ -267,18 +273,18 @@ def _read_config_parameters(root: str, config: str | None, reporter: ProgressRep
if settings_yaml.exists():
reporter.success(f"Reading settings from {settings_yaml}")
with settings_yaml.open("r") as file:
with settings_yaml.open("rb") as file:
import yaml
data = yaml.safe_load(file)
data = yaml.safe_load(file.read().decode(encoding="utf-8", errors="strict"))
return create_graphrag_config(data, root)
if settings_json.exists():
reporter.success(f"Reading settings from {settings_json}")
with settings_json.open("r") as file:
with settings_json.open("rb") as file:
import json
data = json.loads(file.read())
data = json.loads(file.read().decode(encoding="utf-8", errors="strict"))
return create_graphrag_config(data, root)
reporter.success("Reading settings from environment variables")

View File

@ -26,8 +26,8 @@ def load_pipeline_config(config_or_path: str | PipelineConfig) -> PipelineConfig
read_dotenv(str(Path(config_or_path).parent))
if config_or_path.endswith(".json"):
with Path(config_or_path).open(encoding="utf-8") as f:
config = json.load(f)
with Path(config_or_path).open("rb") as f:
config = json.loads(f.read().decode(encoding="utf-8", errors="strict"))
elif config_or_path.endswith((".yml", ".yaml")):
config = _parse_yaml(config_or_path)
else:
@ -73,7 +73,7 @@ def _create_include_constructor():
if filename.endswith((".yml", ".yaml")):
return _parse_yaml(filename)
with Path(filename).open(encoding="utf-8") as f:
return f.read()
with Path(filename).open("rb") as f:
return f.read().decode(encoding="utf-8", errors="strict")
return handle_include

View File

@ -21,8 +21,8 @@ class FileWorkflowCallbacks(NoopWorkflowCallbacks):
def __init__(self, directory: str):
"""Create a new file-based workflow reporter."""
Path(directory).mkdir(parents=True, exist_ok=True)
self._out_stream = open( # noqa SIM115
Path(directory) / "logs.json", "a", encoding="utf-8"
self._out_stream = open( # noqa: PTH123, SIM115
Path(directory) / "logs.json", "a", encoding="utf-8", errors="strict"
)
def on_error(

View File

@ -114,7 +114,9 @@ class FilePipelineStorage(PipelineStorage):
write_type = "wb" if is_bytes else "w"
encoding = None if is_bytes else encoding or self._encoding
async with aiofiles.open(
join_path(self._root_dir, key), cast(Any, write_type), encoding=encoding
join_path(self._root_dir, key),
cast(Any, write_type),
encoding=encoding,
) as f:
await f.write(value)

View File

@ -6,7 +6,6 @@
from .dicts import dict_has_keys_with_types
from .hashing import gen_md5_hash
from .is_null import is_null
from .json import clean_up_json
from .load_graph import load_graph
from .string import clean_str
from .tokens import num_tokens_from_string, string_from_tokens
@ -15,7 +14,6 @@ from .uuid import gen_uuid
__all__ = [
"clean_str",
"clean_up_json",
"dict_has_keys_with_types",
"gen_md5_hash",
"gen_uuid",

View File

@ -1,27 +0,0 @@
# Copyright (c) 2024 Microsoft Corporation.
# Licensed under the MIT License
"""JSON cleaning and formatting utilities."""
def clean_up_json(json_str: str):
"""Clean up json string."""
json_str = (
json_str.replace("\\n", "")
.replace("\n", "")
.replace("\r", "")
.replace('"[{', "[{")
.replace('}]"', "}]")
.replace("\\", "")
.strip()
)
# Remove JSON Markdown Frame
if json_str.startswith("```json"):
json_str = json_str[len("```json") :]
if json_str.startswith("json"):
json_str = json_str[len("json") :]
if json_str.endswith("```"):
json_str = json_str[: len(json_str) - len("```")]
return json_str

View File

@ -1,25 +0,0 @@
# Copyright (c) 2024 Microsoft Corporation.
# Licensed under the MIT License
"""JSON cleaning and formatting utilities."""
def clean_up_json(json_str: str) -> str:
"""Clean up json string."""
json_str = (
json_str.replace("\\n", "")
.replace("\n", "")
.replace("\r", "")
.replace('"[{', "[{")
.replace('}]"', "}]")
.replace("\\", "")
.strip()
)
# Remove JSON Markdown Frame
if json_str.startswith("```json"):
json_str = json_str[len("```json") :]
if json_str.endswith("```"):
json_str = json_str[: len(json_str) - len("```")]
return json_str

View File

@ -16,7 +16,6 @@ from graphrag.llm.types import (
LLMOutput,
)
from ._json import clean_up_json
from ._prompts import JSON_CHECK_PROMPT
from .openai_configuration import OpenAIConfiguration
from .types import OpenAIClientTypes
@ -104,11 +103,10 @@ class OpenAIChatLLM(BaseLLM[CompletionInput, CompletionOutput]):
},
)
raw_output = result.output or ""
json_output = try_parse_json_object(raw_output)
output, json_output = try_parse_json_object(result.output or "")
return LLMOutput[CompletionOutput](
output=raw_output,
output=output,
json=json_output,
history=result.history,
)
@ -119,24 +117,22 @@ class OpenAIChatLLM(BaseLLM[CompletionInput, CompletionOutput]):
# Otherwise, clean up the output and try to parse it as json
result = await self._invoke(input, **kwargs)
history = result.history or []
output = clean_up_json(result.output or "")
try:
json_output = try_parse_json_object(output)
output, json_output = try_parse_json_object(result.output or "")
if json_output:
return LLMOutput[CompletionOutput](
output=output, json=json_output, history=history
output=result.output, json=json_output, history=history
)
except (TypeError, JSONDecodeError):
log.warning("error parsing llm json, retrying")
# If cleaned up json is unparsable, use the LLM to reformat it (may throw)
result = await self._try_clean_json_with_llm(output, **kwargs)
output = clean_up_json(result.output or "")
json = try_parse_json_object(output)
# if not return correct formatted json, retry
log.warning("error parsing llm json, retrying")
# If cleaned up json is unparsable, use the LLM to reformat it (may throw)
result = await self._try_clean_json_with_llm(output, **kwargs)
output, json_output = try_parse_json_object(result.output or "")
return LLMOutput[CompletionOutput](
output=output,
json=json,
history=history,
)
return LLMOutput[CompletionOutput](
output=output,
json=json_output,
history=history,
)
async def _try_clean_json_with_llm(
self, output: str, **kwargs: Unpack[LLMInput]

View File

@ -2,13 +2,14 @@
# Licensed under the MIT License
"""Utility functions for the OpenAI API."""
import re
import json
import logging
from collections.abc import Callable
from typing import Any
import tiktoken
from json_repair import repair_json
from openai import (
APIConnectionError,
InternalServerError,
@ -87,17 +88,52 @@ def get_completion_llm_args(
}
def try_parse_json_object(input: str) -> dict:
"""Generate JSON-string output using best-attempt prompting & parsing techniques."""
def try_parse_json_object(input: str) -> tuple[str, dict]:
"""JSON cleaning and formatting utilities."""
"""sometime, the llm return a json string with some extra description, this function will clean it up."""
_pattern = r"\{(.*)\}"
_match = re.search(_pattern, input)
input = "{" + _match.group(1) + "}" if _match else input
"""Clean up json string."""
input = (
input.replace("{{","{")
.replace("}}","}")
.replace('"[{', "[{")
.replace('}]"', "}]")
.replace("\\", " ")
.replace("\\n", " ")
.replace("\n", " ")
.replace("\r", "")
.strip()
)
# Remove JSON Markdown Frame
if input.startswith("```json"):
input = input[len("```json") :]
if input.endswith("```"):
input = input[: len(input) - len("```")]
try:
result = json.loads(input)
except json.JSONDecodeError:
log.exception("error loading json, json=%s", input)
raise
"""Fixup potentially malformed json string using json_repair."""
input = str(repair_json(json_str=input, return_objects=False))
"""Generate JSON-string output using best-attempt prompting & parsing techniques."""
try:
result = json.loads(input)
except json.JSONDecodeError:
log.exception("error loading json, json=%s", input)
return input, {}
else:
if not isinstance(result, dict):
log.exception("not expected dict type. type=%s:", type(result))
return input, {}
return input, result
else:
if not isinstance(result, dict):
raise TypeError
return result
return input, result
def get_sleep_time_from_error(e: Any) -> float:
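A hedged usage sketch of the reworked helper (the messy input string is made up; it assumes json_repair is installed, per the pyproject change further down):

```python
from graphrag.llm.openai.utils import try_parse_json_object

# Hypothetical LLM output: a markdown fence plus trailing chatter.
raw = '```json\n{"points": [{"answer": "A", "score": 80}]}\n``` Hope this helps!'

cleaned, parsed = try_parse_json_object(raw)
# `parsed` is a dict on success; on failure it is {} and callers fall back to
# a retry (see the OpenAIChatLLM changes above) instead of raising.
print(cleaned, parsed)
```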

View File

@ -10,7 +10,7 @@ from enum import Enum
from graphrag.prompt_tune.generator import MAX_TOKEN_COUNT
from graphrag.prompt_tune.loader import MIN_CHUNK_SIZE
from .cli import fine_tune
from .cli import prompt_tune
class DocSelectionType(Enum):
@ -19,6 +19,7 @@ class DocSelectionType(Enum):
ALL = "all"
RANDOM = "random"
TOP = "top"
AUTO = "auto"
def __str__(self):
"""Return the string representation of the enum value."""
@ -46,13 +47,29 @@ if __name__ == "__main__":
parser.add_argument(
"--method",
help="The method to select documents, one of: all, random or top",
help="The method to select documents, one of: all, random, top or auto",
required=False,
type=DocSelectionType,
choices=list(DocSelectionType),
default=DocSelectionType.RANDOM,
)
parser.add_argument(
"--n_subset_max",
help="The number of text chunks to embed when using auto selection method",
required=False,
type=int,
default=300,
)
parser.add_argument(
"--k",
help="The maximum number of documents to select from each centroid when using auto selection method",
required=False,
type=int,
default=15,
)
parser.add_argument(
"--limit",
help="The limit of files to load when doing random or top selection",
@ -69,6 +86,14 @@ if __name__ == "__main__":
default=MAX_TOKEN_COUNT,
)
parser.add_argument(
"--min-examples-required",
help="The minimum number of examples required in entity extraction prompt",
type=int,
required=False,
default=2,
)
parser.add_argument(
"--chunk-size",
help="Max token count for prompt generation",
@ -106,7 +131,7 @@ if __name__ == "__main__":
loop = asyncio.get_event_loop()
loop.run_until_complete(
fine_tune(
prompt_tune(
args.root,
args.domain,
str(args.method),
@ -116,5 +141,8 @@ if __name__ == "__main__":
args.language,
args.no_entity_types,
args.output,
args.n_subset_max,
args.k,
args.min_examples_required,
)
)

View File

@ -32,7 +32,7 @@ from graphrag.prompt_tune.loader import (
)
async def fine_tune(
async def prompt_tune(
root: str,
domain: str,
select: str = "random",
@ -42,8 +42,11 @@ async def fine_tune(
language: str | None = None,
skip_entity_types: bool = False,
output: str = "prompts",
n_subset_max: int = 300,
k: int = 15,
min_examples_required: int = 2,
):
"""Fine tune the model.
"""Prompt tune the model.
Parameters
----------
@ -55,11 +58,13 @@ async def fine_tune(
- chunk_size: The chunk token size to use.
- skip_entity_types: Skip generating entity types.
- output: The output folder to store the prompts.
- n_subset_max: The number of text chunks to embed when using auto selection method.
- k: The number of documents to select when using auto selection method.
"""
reporter = PrintProgressReporter("")
config = read_config_parameters(root, reporter)
await fine_tune_with_config(
await prompt_tune_with_config(
root,
config,
domain,
@ -71,10 +76,13 @@ async def fine_tune(
skip_entity_types,
output,
reporter,
n_subset_max,
k,
min_examples_required,
)
async def fine_tune_with_config(
async def prompt_tune_with_config(
root: str,
config: GraphRagConfig,
domain: str,
@ -86,8 +94,11 @@ async def fine_tune_with_config(
skip_entity_types: bool = False,
output: str = "prompts",
reporter: ProgressReporter | None = None,
n_subset_max: int = 300,
k: int = 15,
min_examples_required: int = 2,
):
"""Fine tune the model with a configuration.
"""Prompt tune the model with a configuration.
Parameters
----------
@ -101,6 +112,8 @@ async def fine_tune_with_config(
- skip_entity_types: Skip generating entity types.
- output: The output folder to store the prompts.
- reporter: The progress reporter.
- n_subset_max: The number of text chunks to embed when using auto selection method.
- k: The number of documents to select when using auto selection method.
Returns
-------
@ -118,11 +131,13 @@ async def fine_tune_with_config(
select_method=select,
reporter=reporter,
chunk_size=chunk_size,
n_subset_max=n_subset_max,
k=k,
)
# Create LLM from config
llm = load_llm(
"fine_tuning",
"prompt_tuning",
config.llm.type,
NoopVerbCallbacks(),
None,
@ -139,6 +154,7 @@ async def fine_tune_with_config(
language,
max_tokens,
skip_entity_types,
min_examples_required,
)
@ -152,6 +168,7 @@ async def generate_indexing_prompts(
language: str | None = None,
max_tokens: int = MAX_TOKEN_COUNT,
skip_entity_types: bool = False,
min_examples_required: int = 2,
):
"""Generate indexing prompts.
@ -165,6 +182,7 @@ async def generate_indexing_prompts(
- domain: The domain to map the input documents to.
- max_tokens: The maximum number of tokens to use on entity extraction prompts
- skip_entity_types: Skip generating entity types.
- min_examples_required: The minimum number of examples required for entity extraction prompts.
"""
if not domain:
reporter.info("Generating domain...")
@ -221,6 +239,7 @@ async def generate_indexing_prompts(
output_path=output_path,
encoding_model=config.encoding_model,
max_token_count=max_tokens,
min_examples_required=min_examples_required,
)
reporter.info(f"Generated entity extraction prompt, stored in folder {output_path}")

View File

@ -42,7 +42,7 @@ def create_community_summarization_prompt(
output_path = output_path / COMMUNITY_SUMMARIZATION_FILENAME
# Write file to output path
with output_path.open("w") as file:
file.write(prompt)
with output_path.open("wb") as file:
file.write(prompt.encode(encoding="utf-8", errors="strict"))
return prompt

View File

@ -27,6 +27,7 @@ def create_entity_extraction_prompt(
encoding_model: str = defs.ENCODING_MODEL,
json_mode: bool = False,
output_path: Path | None = None,
min_examples_required: int = 2,
) -> str:
"""
Create a prompt for entity extraction.
@ -41,6 +42,7 @@ def create_entity_extraction_prompt(
- max_token_count (int): The maximum number of tokens to use for the prompt
- json_mode (bool): Whether to use JSON mode for the prompt. Default is False
- output_path (Path | None): The path to write the prompt to. Default is None. If None, the prompt is not written to a file. Default is None.
- min_examples_required (int): The minimum number of examples required. Default is 2.
Returns
-------
@ -79,8 +81,8 @@ def create_entity_extraction_prompt(
example_tokens = num_tokens_from_string(example_formatted, model=encoding_model)
# Squeeze in at least one example
if i > 0 and example_tokens > tokens_left:
# Ensure at least three examples are included
if i >= min_examples_required and example_tokens > tokens_left:
break
examples_prompt += example_formatted
@ -99,7 +101,7 @@ def create_entity_extraction_prompt(
output_path = output_path / ENTITY_EXTRACTION_FILENAME
# Write file to output path
with output_path.open("w") as file:
file.write(prompt)
with output_path.open("wb") as file:
file.write(prompt.encode(encoding="utf-8", errors="strict"))
return prompt

View File

@ -30,7 +30,7 @@ def create_entity_summarization_prompt(
output_path = output_path / ENTITY_SUMMARIZATION_FILENAME
# Write file to output path
with output_path.open("w") as file:
file.write(prompt)
with output_path.open("wb") as file:
file.write(prompt.encode(encoding="utf-8", errors="strict"))
return prompt

View File

@ -25,18 +25,18 @@ def read_config_parameters(root: str, reporter: ProgressReporter):
if settings_yaml.exists():
reporter.info(f"Reading settings from {settings_yaml}")
with settings_yaml.open("r") as file:
with settings_yaml.open("rb") as file:
import yaml
data = yaml.safe_load(file)
data = yaml.safe_load(file.read().decode(encoding="utf-8", errors="strict"))
return create_graphrag_config(data, root)
if settings_json.exists():
reporter.info(f"Reading settings from {settings_json}")
with settings_json.open("r") as file:
with settings_json.open("rb") as file:
import json
data = json.loads(file.read())
data = json.loads(file.read().decode(encoding="utf-8", errors="strict"))
return create_graphrag_config(data, root)
reporter.info("Reading settings from environment variables")

View File

@ -5,16 +5,45 @@
from typing import cast
import numpy as np
import pandas as pd
from datashaper import NoopVerbCallbacks, TableContainer, VerbInput
from graphrag.config.models.graph_rag_config import GraphRagConfig
from graphrag.index.input import load_input
from graphrag.index.llm import load_llm_embeddings
from graphrag.index.progress.types import ProgressReporter
from graphrag.index.verbs import chunk
from graphrag.llm.types.llm_types import EmbeddingLLM
MIN_CHUNK_SIZE = 200
MIN_CHUNK_OVERLAP = 0
MIN_CHUNK_SIZE = 200
N_SUBSET_MAX = 300
K = 15
async def _embed_chunks(
text_chunks: pd.DataFrame,
embedding_llm: EmbeddingLLM,
n_subset_max: int = N_SUBSET_MAX,
) -> tuple[pd.DataFrame, np.ndarray]:
"""Convert text chunks into dense text embeddings."""
sampled_text_chunks = text_chunks.sample(n=min(n_subset_max, len(text_chunks)))
embeddings = await embedding_llm(sampled_text_chunks["chunks"].tolist())
return text_chunks, np.array(embeddings.output)
def _sample_chunks_from_embeddings(
text_chunks: pd.DataFrame,
embeddings,
k: int = K,
) -> pd.DataFrame:
"""Sample text chunks from embeddings."""
center = np.mean(embeddings, axis=0)
distances = np.linalg.norm(embeddings - center, axis=1)
nearest_indices = np.argsort(distances)[:k]
return text_chunks.iloc[nearest_indices]
async def load_docs_in_chunks(
@ -24,6 +53,8 @@ async def load_docs_in_chunks(
limit: int,
reporter: ProgressReporter,
chunk_size: int = MIN_CHUNK_SIZE,
n_subset_max: int = N_SUBSET_MAX,
k: int = K,
) -> list[str]:
"""Load docs into chunks for generating prompts."""
dataset = await load_input(config.input, reporter, root)
@ -57,6 +88,22 @@ async def load_docs_in_chunks(
chunks_df = chunks_df[:limit]
elif select_method == "random":
chunks_df = chunks_df.sample(n=limit)
elif select_method == "auto":
if k is None or k <= 0:
msg = "k must be an integer > 0"
raise ValueError(msg)
embedding_llm = load_llm_embeddings(
name="prompt_tuning_embeddings",
llm_type=config.embeddings.resolved_strategy()["llm"]["type"],
callbacks=NoopVerbCallbacks(),
cache=None,
llm_config=config.embeddings.resolved_strategy()["llm"],
)
chunks_df, embeddings = await _embed_chunks(
chunks_df, embedding_llm, n_subset_max=n_subset_max
)
chunks_df = _sample_chunks_from_embeddings(chunks_df, embeddings, k=k)
# Convert the dataset to list form, so we have a list of documents
return chunks_df["chunks"].tolist()
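The auto method is essentially a centroid-nearest-neighbour pick over chunk embeddings; a self-contained toy illustration of the same selection step (the chunk texts and embedding vectors are made up):

```python
import numpy as np
import pandas as pd

# Toy stand-ins for chunk text and their embeddings.
chunks = pd.DataFrame({"chunks": ["chunk a", "chunk b", "chunk c", "chunk d"]})
embeddings = np.array([[0.1, 0.0], [0.9, 0.9], [0.2, 0.1], [0.0, 0.2]])

k = 2
center = np.mean(embeddings, axis=0)                # centroid of all embeddings
distances = np.linalg.norm(embeddings - center, axis=1)
nearest = np.argsort(distances)[:k]                 # indices of the k closest chunks
print(chunks.iloc[nearest]["chunks"].tolist())
```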

View File

@ -110,7 +110,6 @@ def run_local_search(
final_relationships = pd.read_parquet(
data_path / "create_final_relationships.parquet"
)
final_nodes = pd.read_parquet(data_path / "create_final_nodes.parquet")
final_entities = pd.read_parquet(data_path / "create_final_entities.parquet")
final_covariates_path = data_path / "create_final_covariates.parquet"
final_covariates = (
@ -194,18 +193,20 @@ def _read_config_parameters(root: str):
if settings_yaml.exists():
reporter.info(f"Reading settings from {settings_yaml}")
with settings_yaml.open("r") as file:
with settings_yaml.open(
"rb",
) as file:
import yaml
data = yaml.safe_load(file)
data = yaml.safe_load(file.read().decode(encoding="utf-8", errors="strict"))
return create_graphrag_config(data, root)
if settings_json.exists():
reporter.info(f"Reading settings from {settings_json}")
with settings_json.open("r") as file:
with settings_json.open("rb") as file:
import json
data = json.loads(file.read())
data = json.loads(file.read().decode(encoding="utf-8", errors="strict"))
return create_graphrag_config(data, root)
reporter.info("Reading settings from environment variables")

View File

@ -58,6 +58,7 @@ def get_llm(config: GraphRagConfig) -> ChatOpenAI:
else None
),
api_base=config.llm.api_base,
organization=config.llm.organization,
model=config.llm.model,
api_type=OpenaiApiType.AzureOpenAI if is_azure_client else OpenaiApiType.OpenAI,
deployment_name=config.llm.deployment_name,
@ -89,6 +90,7 @@ def get_text_embedder(config: GraphRagConfig) -> OpenAIEmbedding:
else None
),
api_base=config.embeddings.llm.api_base,
organization=config.llm.organization,
api_type=OpenaiApiType.AzureOpenAI if is_azure_client else OpenaiApiType.OpenAI,
model=config.embeddings.llm.model,
deployment_name=config.embeddings.llm.deployment_name,

View File

@ -13,7 +13,7 @@ from typing import Any
import pandas as pd
import tiktoken
from graphrag.index.utils.json import clean_up_json
from graphrag.llm.openai.utils import try_parse_json_object
from graphrag.query.context_builder.builders import GlobalContextBuilder
from graphrag.query.context_builder.conversation_history import (
ConversationHistory,
@ -32,7 +32,7 @@ from graphrag.query.structured_search.global_search.reduce_system_prompt import
NO_DATA_ANSWER,
REDUCE_SYSTEM_PROMPT,
)
from graphrag.llm.openai.utils import try_parse_json_object
DEFAULT_MAP_LLM_PARAMS = {
"max_tokens": 1000,
"temperature": 0.0,
@ -188,7 +188,6 @@ class GlobalSearch(BaseSearch):
processed_response = self.parse_search_response(search_response)
except ValueError:
# Clean up and retry parse
search_response = clean_up_json(search_response)
try:
# parse search response json
processed_response = self.parse_search_response(search_response)
@ -229,6 +228,10 @@ class GlobalSearch(BaseSearch):
list[dict[str, Any]]
A list of key points, each key point is a dictionary with "answer" and "score" keys
"""
search_response,_j = try_parse_json_object(search_response)
if _j =={}:
return [{"answer":"not avaliable","score": 0}]
parsed_elements = json.loads(search_response)["points"]
return [
{

poetry.lock (generated, 1483 lines changed)

File diff suppressed because it is too large

View File

@ -85,6 +85,7 @@ typing-extensions = "^4.12.2"
#Azure
azure-storage-blob = "^12.19.0"
azure-identity = "^1.17.1"
json-repair = "^0.25.3"
[tool.poetry.group.dev.dependencies]
coverage = "^7.6.0"

View File

@ -38,8 +38,7 @@ def _load_fixtures():
continue
config_file = fixtures_path / subfolder / "config.json"
with config_file.open() as f:
params.append((subfolder, json.load(f)))
params.append((subfolder, json.loads(config_file.read_bytes().decode("utf-8"))))
return params
@ -104,8 +103,7 @@ async def prepare_azurite_data(input_path: str, azure: dict) -> Callable[[], Non
csv_files = list((root / "input").glob("*.csv"))
data_files = txt_files + csv_files
for data_file in data_files:
with data_file.open(encoding="utf8") as f:
text = f.read()
text = data_file.read_bytes().decode("utf-8")
file_path = (
str(Path(input_base_dir) / data_file.name)
if input_base_dir
@ -166,8 +164,7 @@ class TestIndexer:
assert artifacts.exists(), "artifact folder does not exist"
# Check stats for all workflow
with (artifacts / "stats.json").open() as f:
stats = json.load(f)
stats = json.loads((artifacts / "stats.json").read_bytes().decode("utf-8"))
# Check all workflows run
expected_workflows = set(workflow_config.keys())