mirror of
https://github.com/microsoft/graphrag.git
synced 2026-01-14 00:57:23 +08:00
Task/remove dynamic retries (#1941)
Some checks are pending
gh-pages / build (push) Waiting to run
Python CI / python-ci (ubuntu-latest, 3.10) (push) Waiting to run
Python CI / python-ci (ubuntu-latest, 3.11) (push) Waiting to run
Python CI / python-ci (windows-latest, 3.10) (push) Waiting to run
Python CI / python-ci (windows-latest, 3.11) (push) Waiting to run
Python Integration Tests / python-ci (ubuntu-latest, 3.10) (push) Waiting to run
Python Integration Tests / python-ci (windows-latest, 3.10) (push) Waiting to run
Python Notebook Tests / python-ci (ubuntu-latest, 3.10) (push) Waiting to run
Python Notebook Tests / python-ci (windows-latest, 3.10) (push) Waiting to run
Python Publish (pypi) / Upload release to PyPI (push) Waiting to run
Python Smoke Tests / python-ci (ubuntu-latest, 3.10) (push) Waiting to run
Python Smoke Tests / python-ci (windows-latest, 3.10) (push) Waiting to run
Spellcheck / spellcheck (push) Waiting to run
* Remove max retries. Update Typer args
* Format
* Semver
* Fix typo
* Ruff and Typos
* Format
This commit is contained in:
parent
36948b8d2e
commit
24018c6155
@@ -0,0 +1,4 @@
{
    "type": "minor",
    "description": "Remove Dynamic Max Retries support. Refactor typer typing in cli interface"
}
@@ -17,7 +17,7 @@ import annotated_types
from pydantic import PositiveInt, validate_call

from graphrag.callbacks.noop_workflow_callbacks import NoopWorkflowCallbacks
from graphrag.config.defaults import graphrag_config_defaults, language_model_defaults
from graphrag.config.defaults import graphrag_config_defaults
from graphrag.config.models.graph_rag_config import GraphRagConfig
from graphrag.language_model.manager import ModelManager
from graphrag.logger.base import ProgressLogger
@@ -109,15 +109,6 @@ async def generate_indexing_prompts(
    logger.info("Retrieving language model configuration...")
    default_llm_settings = config.get_language_model_config(PROMPT_TUNING_MODEL_ID)

    # if max_retries is not set, inject a dynamically assigned value based on the number of expected LLM calls
    # to be made or fallback to a default value in the worst case
    if default_llm_settings.max_retries == -1:
        default_llm_settings.max_retries = min(
            len(doc_list), language_model_defaults.max_retries
        )
        msg = f"max_retries not set, using default value: {default_llm_settings.max_retries}"
        logger.warning(msg)

    logger.info("Creating language model...")
    llm = ModelManager().register_chat(
        name="prompt_tuning",
@@ -7,7 +7,6 @@ import os
import re
from collections.abc import Callable
from pathlib import Path
from typing import Annotated

import typer
@@ -78,25 +77,40 @@ def path_autocomplete(
    return completer


CONFIG_AUTOCOMPLETE = path_autocomplete(
    file_okay=True,
    dir_okay=False,
    match_wildcard="*.yaml",
    readable=True,
)

ROOT_AUTOCOMPLETE = path_autocomplete(
    file_okay=False,
    dir_okay=True,
    writable=True,
    match_wildcard="*",
)


@app.command("init")
def _initialize_cli(
    root: Annotated[
        Path,
        typer.Option(
            help="The project root directory.",
            dir_okay=True,
            writable=True,
            resolve_path=True,
            autocompletion=path_autocomplete(
                file_okay=False, dir_okay=True, writable=True, match_wildcard="*"
            ),
        ),
    ],
    force: Annotated[
        bool,
        typer.Option(help="Force initialization even if the project already exists."),
    ] = False,
):
    root: Path = typer.Option(
        Path(),
        "--root",
        "-r",
        help="The project root directory.",
        dir_okay=True,
        writable=True,
        resolve_path=True,
        autocompletion=ROOT_AUTOCOMPLETE,
    ),
    force: bool = typer.Option(
        False,
        "--force",
        "-f",
        help="Force initialization even if the project already exists.",
    ),
) -> None:
    """Generate a default configuration file."""
    from graphrag.cli.initialize import initialize_project_at
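Reviewer note: the hunk above shows the direction of the Typer refactor — parameters move from Annotated[...] metadata to explicit typer.Option defaults with named flags, and the ad-hoc path_autocomplete(...) calls are hoisted into module-level constants (CONFIG_AUTOCOMPLETE, ROOT_AUTOCOMPLETE). A minimal standalone sketch of that style follows; it is not graphrag's module, and the completer body is illustrative only.

from pathlib import Path

import typer

app = typer.Typer()


def path_autocomplete(file_okay: bool, dir_okay: bool, match_wildcard: str):
    # Illustrative completer factory; the real graphrag helper takes more options.
    def completer(incomplete: str) -> list[str]:
        return [
            str(p)
            for p in Path().glob(match_wildcard)
            if str(p).startswith(incomplete)
            and ((file_okay and p.is_file()) or (dir_okay and p.is_dir()))
        ]

    return completer


ROOT_AUTOCOMPLETE = path_autocomplete(file_okay=False, dir_okay=True, match_wildcard="*")


@app.command("init")
def initialize(
    root: Path = typer.Option(
        Path(),
        "--root",
        "-r",
        help="The project root directory.",
        dir_okay=True,
        writable=True,
        resolve_path=True,
        autocompletion=ROOT_AUTOCOMPLETE,
    ),
    force: bool = typer.Option(False, "--force", "-f", help="Overwrite existing files."),
) -> None:
    typer.echo(f"init {root} force={force}")


if __name__ == "__main__":
    app()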
@@ -105,60 +119,80 @@ def _initialize_cli(

@app.command("index")
def _index_cli(
    config: Annotated[
        Path | None,
        typer.Option(
            help="The configuration to use.", exists=True, file_okay=True, readable=True
    config: Path | None = typer.Option(
        None,
        "--config",
        "-c",
        help="The configuration to use.",
        exists=True,
        file_okay=True,
        readable=True,
        autocompletion=CONFIG_AUTOCOMPLETE,
    ),
    root: Path = typer.Option(
        Path(),
        "--root",
        "-r",
        help="The project root directory.",
        exists=True,
        dir_okay=True,
        writable=True,
        resolve_path=True,
        autocompletion=ROOT_AUTOCOMPLETE,
    ),
    method: IndexingMethod = typer.Option(
        IndexingMethod.Standard.value,
        "--method",
        "-m",
        help="The indexing method to use.",
    ),
    verbose: bool = typer.Option(
        False,
        "--verbose",
        "-v",
        help="Run the indexing pipeline with verbose logging",
    ),
    memprofile: bool = typer.Option(
        False,
        "--memprofile",
        help="Run the indexing pipeline with memory profiling",
    ),
    logger: LoggerType = typer.Option(
        LoggerType.RICH.value,
        "--logger",
        help="The progress logger to use.",
    ),
    dry_run: bool = typer.Option(
        False,
        "--dry-run",
        help=(
            "Run the indexing pipeline without executing any steps "
            "to inspect and validate the configuration."
        ),
    ] = None,
    root: Annotated[
        Path,
        typer.Option(
            help="The project root directory.",
            exists=True,
            dir_okay=True,
            writable=True,
            resolve_path=True,
            autocompletion=path_autocomplete(
                file_okay=False, dir_okay=True, writable=True, match_wildcard="*"
            ),
        ),
    cache: bool = typer.Option(
        True,
        "--cache/--no-cache",
        help="Use LLM cache.",
    ),
    skip_validation: bool = typer.Option(
        False,
        "--skip-validation",
        help="Skip any preflight validation. Useful when running no LLM steps.",
    ),
    output: Path | None = typer.Option(
        None,
        "--output",
        "-o",
        help=(
            "Indexing pipeline output directory. "
            "Overrides output.base_dir in the configuration file."
        ),
    ] = Path(),  # set default to current directory
    method: Annotated[
        IndexingMethod, typer.Option(help="The indexing method to use.")
    ] = IndexingMethod.Standard,
    verbose: Annotated[
        bool, typer.Option(help="Run the indexing pipeline with verbose logging")
    ] = False,
    memprofile: Annotated[
        bool, typer.Option(help="Run the indexing pipeline with memory profiling")
    ] = False,
    logger: Annotated[
        LoggerType, typer.Option(help="The progress logger to use.")
    ] = LoggerType.RICH,
    dry_run: Annotated[
        bool,
        typer.Option(
            help="Run the indexing pipeline without executing any steps to inspect and validate the configuration."
        ),
    ] = False,
    cache: Annotated[bool, typer.Option(help="Use LLM cache.")] = True,
    skip_validation: Annotated[
        bool,
        typer.Option(
            help="Skip any preflight validation. Useful when running no LLM steps."
        ),
    ] = False,
    output: Annotated[
        Path | None,
        typer.Option(
            help="Indexing pipeline output directory. Overrides output.base_dir in the configuration file.",
            dir_okay=True,
            writable=True,
            resolve_path=True,
        ),
    ] = None,
):
        dir_okay=True,
        writable=True,
        resolve_path=True,
    ),
) -> None:
    """Build a knowledge graph index."""
    from graphrag.cli.index import index_cli
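One detail worth calling out in the index command above is the explicit boolean flag pair ("--cache/--no-cache") combined with short aliases ("-v"). A tiny, self-contained toy command (hypothetical names, not graphrag's) showing how that declaration behaves:

import typer

app = typer.Typer()


@app.command()
def run(
    cache: bool = typer.Option(True, "--cache/--no-cache", help="Use LLM cache."),
    verbose: bool = typer.Option(False, "--verbose", "-v", help="Verbose logging."),
) -> None:
    # `--no-cache` flips the default True to False; `-v` is shorthand for --verbose.
    typer.echo(f"cache={cache} verbose={verbose}")


if __name__ == "__main__":
    app()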
@@ -178,51 +212,72 @@ def _index_cli(

@app.command("update")
def _update_cli(
    config: Annotated[
        Path | None,
        typer.Option(
            help="The configuration to use.", exists=True, file_okay=True, readable=True
    config: Path | None = typer.Option(
        None,
        "--config",
        "-c",
        help="The configuration to use.",
        exists=True,
        file_okay=True,
        readable=True,
        autocompletion=CONFIG_AUTOCOMPLETE,
    ),
    root: Path = typer.Option(
        Path(),
        "--root",
        "-r",
        help="The project root directory.",
        exists=True,
        dir_okay=True,
        writable=True,
        resolve_path=True,
        autocompletion=ROOT_AUTOCOMPLETE,
    ),
    method: IndexingMethod = typer.Option(
        IndexingMethod.Standard.value,
        "--method",
        "-m",
        help="The indexing method to use.",
    ),
    verbose: bool = typer.Option(
        False,
        "--verbose",
        "-v",
        help="Run the indexing pipeline with verbose logging.",
    ),
    memprofile: bool = typer.Option(
        False,
        "--memprofile",
        help="Run the indexing pipeline with memory profiling.",
    ),
    logger: LoggerType = typer.Option(
        LoggerType.RICH.value,
        "--logger",
        help="The progress logger to use.",
    ),
    cache: bool = typer.Option(
        True,
        "--cache/--no-cache",
        help="Use LLM cache.",
    ),
    skip_validation: bool = typer.Option(
        False,
        "--skip-validation",
        help="Skip any preflight validation. Useful when running no LLM steps.",
    ),
    output: Path | None = typer.Option(
        None,
        "--output",
        "-o",
        help=(
            "Indexing pipeline output directory. "
            "Overrides output.base_dir in the configuration file."
        ),
    ] = None,
    root: Annotated[
        Path,
        typer.Option(
            help="The project root directory.",
            exists=True,
            dir_okay=True,
            writable=True,
            resolve_path=True,
        ),
    ] = Path(),  # set default to current directory
    method: Annotated[
        IndexingMethod, typer.Option(help="The indexing method to use.")
    ] = IndexingMethod.Standard,
    verbose: Annotated[
        bool, typer.Option(help="Run the indexing pipeline with verbose logging")
    ] = False,
    memprofile: Annotated[
        bool, typer.Option(help="Run the indexing pipeline with memory profiling")
    ] = False,
    logger: Annotated[
        LoggerType, typer.Option(help="The progress logger to use.")
    ] = LoggerType.RICH,
    cache: Annotated[bool, typer.Option(help="Use LLM cache.")] = True,
    skip_validation: Annotated[
        bool,
        typer.Option(
            help="Skip any preflight validation. Useful when running no LLM steps."
        ),
    ] = False,
    output: Annotated[
        Path | None,
        typer.Option(
            help="Indexing pipeline output directory. Overrides output.base_dir in the configuration file.",
            dir_okay=True,
            writable=True,
            resolve_path=True,
        ),
    ] = None,
):
        dir_okay=True,
        writable=True,
        resolve_path=True,
    ),
) -> None:
    """
    Update an existing knowledge graph index.
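The update and index commands also show the enum-backed option pattern: the default is passed as the enum member's value (IndexingMethod.Standard.value, LoggerType.RICH.value) and Typer maps the CLI string back onto the enum. A hedged, standalone sketch with a stand-in enum (member names mirror the diff, but this is not graphrag's class):

from enum import Enum

import typer

app = typer.Typer()


class IndexingMethod(str, Enum):
    # Illustrative members; the real enum lives in graphrag.
    Standard = "standard"
    Fast = "fast"


@app.command()
def update(
    method: IndexingMethod = typer.Option(
        IndexingMethod.Standard.value, "--method", "-m", help="The indexing method to use."
    ),
) -> None:
    # Typer restricts --method to the enum's values and hands back the member.
    typer.echo(f"method={method}")


if __name__ == "__main__":
    app()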
@@ -245,104 +300,107 @@ def _update_cli(

@app.command("prompt-tune")
def _prompt_tune_cli(
    root: Annotated[
        Path,
        typer.Option(
            help="The project root directory.",
            exists=True,
            dir_okay=True,
            writable=True,
            resolve_path=True,
            autocompletion=path_autocomplete(
                file_okay=False, dir_okay=True, writable=True, match_wildcard="*"
            ),
        ),
    root: Path = typer.Option(
        Path(),
        "--root",
        "-r",
        help="The project root directory.",
        exists=True,
        dir_okay=True,
        writable=True,
        resolve_path=True,
        autocompletion=ROOT_AUTOCOMPLETE,
    ),
    config: Path | None = typer.Option(
        None,
        "--config",
        "-c",
        help="The configuration to use.",
        exists=True,
        file_okay=True,
        readable=True,
        autocompletion=CONFIG_AUTOCOMPLETE,
    ),
    verbose: bool = typer.Option(
        False,
        "--verbose",
        "-v",
        help="Run the prompt tuning pipeline with verbose logging.",
    ),
    logger: LoggerType = typer.Option(
        LoggerType.RICH.value,
        "--logger",
        help="The progress logger to use.",
    ),
    domain: str | None = typer.Option(
        None,
        "--domain",
        help=(
            "The domain your input data is related to. "
            "For example 'space science', 'microbiology', 'environmental news'. "
            "If not defined, a domain will be inferred from the input data."
        ),
    ] = Path(),  # set default to current directory
    config: Annotated[
        Path | None,
        typer.Option(
            help="The configuration to use.",
            exists=True,
            file_okay=True,
            readable=True,
            autocompletion=path_autocomplete(
                file_okay=True, dir_okay=False, match_wildcard="*"
            ),
        ),
    ] = None,
    verbose: Annotated[
        bool, typer.Option(help="Run the prompt tuning pipeline with verbose logging")
    ] = False,
    logger: Annotated[
        LoggerType, typer.Option(help="The progress logger to use.")
    ] = LoggerType.RICH,
    domain: Annotated[
        str | None,
        typer.Option(
            help="The domain your input data is related to. For example 'space science', 'microbiology', 'environmental news'. If not defined, a domain will be inferred from the input data."
        ),
    ] = None,
    selection_method: Annotated[
        DocSelectionType, typer.Option(help="The text chunk selection method.")
    ] = DocSelectionType.RANDOM,
    n_subset_max: Annotated[
        int,
        typer.Option(
            help="The number of text chunks to embed when --selection-method=auto."
        ),
    ] = N_SUBSET_MAX,
    k: Annotated[
        int,
        typer.Option(
            help="The maximum number of documents to select from each centroid when --selection-method=auto."
        ),
    ] = K,
    limit: Annotated[
        int,
        typer.Option(
            help="The number of documents to load when --selection-method={random,top}."
        ),
    ] = LIMIT,
    max_tokens: Annotated[
        int, typer.Option(help="The max token count for prompt generation.")
    ] = MAX_TOKEN_COUNT,
    min_examples_required: Annotated[
        int,
        typer.Option(
            help="The minimum number of examples to generate/include in the entity extraction prompt."
        ),
    ] = 2,
    chunk_size: Annotated[
        int,
        typer.Option(
            help="The size of each example text chunk. Overrides chunks.size in the configuration file."
        ),
    ] = graphrag_config_defaults.chunks.size,
    overlap: Annotated[
        int,
        typer.Option(
            help="The overlap size for chunking documents. Overrides chunks.overlap in the configuration file"
        ),
    ] = graphrag_config_defaults.chunks.overlap,
    language: Annotated[
        str | None,
        typer.Option(
            help="The primary language used for inputs and outputs in graphrag prompts."
        ),
    ] = None,
    discover_entity_types: Annotated[
        bool, typer.Option(help="Discover and extract unspecified entity types.")
    ] = True,
    output: Annotated[
        Path,
        typer.Option(
            help="The directory to save prompts to, relative to the project root directory.",
            dir_okay=True,
            writable=True,
            resolve_path=True,
        ),
    ] = Path("prompts"),
):
    ),
    selection_method: DocSelectionType = typer.Option(
        DocSelectionType.RANDOM.value,
        "--selection-method",
        help="The text chunk selection method.",
    ),
    n_subset_max: int = typer.Option(
        N_SUBSET_MAX,
        "--n-subset-max",
        help="The number of text chunks to embed when --selection-method=auto.",
    ),
    k: int = typer.Option(
        K,
        "--k",
        help="The maximum number of documents to select from each centroid when --selection-method=auto.",
    ),
    limit: int = typer.Option(
        LIMIT,
        "--limit",
        help="The number of documents to load when --selection-method={random,top}.",
    ),
    max_tokens: int = typer.Option(
        MAX_TOKEN_COUNT,
        "--max-tokens",
        help="The max token count for prompt generation.",
    ),
    min_examples_required: int = typer.Option(
        2,
        "--min-examples-required",
        help="The minimum number of examples to generate/include in the entity extraction prompt.",
    ),
    chunk_size: int = typer.Option(
        graphrag_config_defaults.chunks.size,
        "--chunk-size",
        help="The size of each example text chunk. Overrides chunks.size in the configuration file.",
    ),
    overlap: int = typer.Option(
        graphrag_config_defaults.chunks.overlap,
        "--overlap",
        help="The overlap size for chunking documents. Overrides chunks.overlap in the configuration file.",
    ),
    language: str | None = typer.Option(
        None,
        "--language",
        help="The primary language used for inputs and outputs in graphrag prompts.",
    ),
    discover_entity_types: bool = typer.Option(
        True,
        "--discover-entity-types/--no-discover-entity-types",
        help="Discover and extract unspecified entity types.",
    ),
    output: Path = typer.Option(
        Path("prompts"),
        "--output",
        "-o",
        help="The directory to save prompts to, relative to the project root directory.",
        dir_okay=True,
        writable=True,
        resolve_path=True,
    ),
) -> None:
    """Generate custom graphrag prompts with your own data (i.e. auto templating)."""
    import asyncio
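Because the prompt-tune command now declares every flag name explicitly ("--chunk-size", "--discover-entity-types/--no-discover-entity-types", ...), the flag spellings are easy to pin down with Typer's test runner. A small hypothetical check against a toy single-command app (not part of this PR, not graphrag's CLI):

import typer
from typer.testing import CliRunner

app = typer.Typer()


@app.command()
def prompt_tune(
    chunk_size: int = typer.Option(1200, "--chunk-size", help="Example chunk size."),
    discover_entity_types: bool = typer.Option(
        True, "--discover-entity-types/--no-discover-entity-types"
    ),
) -> None:
    typer.echo(f"{chunk_size} {discover_entity_types}")


runner = CliRunner()
result = runner.invoke(app, ["--chunk-size", "600", "--no-discover-entity-types"])
assert result.exit_code == 0
assert result.output.strip() == "600 False"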
@@ -373,66 +431,77 @@ def _prompt_tune_cli(

@app.command("query")
def _query_cli(
    method: Annotated[SearchMethod, typer.Option(help="The query algorithm to use.")],
    query: Annotated[str, typer.Option(help="The query to execute.")],
    config: Annotated[
        Path | None,
        typer.Option(
            help="The configuration to use.",
            exists=True,
            file_okay=True,
            readable=True,
            autocompletion=path_autocomplete(
                file_okay=True, dir_okay=False, match_wildcard="*"
            ),
    method: SearchMethod = typer.Option(
        ...,
        "--method",
        "-m",
        help="The query algorithm to use.",
    ),
    query: str = typer.Option(
        ...,
        "--query",
        "-q",
        help="The query to execute.",
    ),
    config: Path | None = typer.Option(
        None,
        "--config",
        "-c",
        help="The configuration to use.",
        exists=True,
        file_okay=True,
        readable=True,
        autocompletion=CONFIG_AUTOCOMPLETE,
    ),
    data: Path | None = typer.Option(
        None,
        "--data",
        "-d",
        help="Index output directory (contains the parquet files).",
        exists=True,
        dir_okay=True,
        readable=True,
        resolve_path=True,
        autocompletion=ROOT_AUTOCOMPLETE,
    ),
    root: Path = typer.Option(
        Path(),
        "--root",
        "-r",
        help="The project root directory.",
        exists=True,
        dir_okay=True,
        writable=True,
        resolve_path=True,
        autocompletion=ROOT_AUTOCOMPLETE,
    ),
    community_level: int = typer.Option(
        2,
        "--community-level",
        help=(
            "Leiden hierarchy level from which to load community reports. "
            "Higher values represent smaller communities."
        ),
    ] = None,
    data: Annotated[
        Path | None,
        typer.Option(
            help="Indexing pipeline output directory (i.e. contains the parquet files).",
            exists=True,
            dir_okay=True,
            readable=True,
            resolve_path=True,
            autocompletion=path_autocomplete(
                file_okay=False, dir_okay=True, match_wildcard="*"
            ),
        ),
    dynamic_community_selection: bool = typer.Option(
        False,
        "--dynamic-community-selection/--no-dynamic-selection",
        help="Use global search with dynamic community selection.",
    ),
    response_type: str = typer.Option(
        "Multiple Paragraphs",
        "--response-type",
        help=(
            "Free-form description of the desired response format "
            "(e.g. 'Single Sentence', 'List of 3-7 Points', etc.)."
        ),
    ] = None,
    root: Annotated[
        Path,
        typer.Option(
            help="The project root directory.",
            exists=True,
            dir_okay=True,
            writable=True,
            resolve_path=True,
            autocompletion=path_autocomplete(
                file_okay=False, dir_okay=True, match_wildcard="*"
            ),
        ),
    ] = Path(),  # set default to current directory
    community_level: Annotated[
        int,
        typer.Option(
            help="The community level in the Leiden community hierarchy from which to load community reports. Higher values represent reports from smaller communities."
        ),
    ] = 2,
    dynamic_community_selection: Annotated[
        bool,
        typer.Option(help="Use global search with dynamic community selection."),
    ] = False,
    response_type: Annotated[
        str,
        typer.Option(
            help="Free form text describing the response type and format, can be anything, e.g. Multiple Paragraphs, Single Paragraph, Single Sentence, List of 3-7 Points, Single Page, Multi-Page Report. Default: Multiple Paragraphs"
        ),
    ] = "Multiple Paragraphs",
    streaming: Annotated[
        bool, typer.Option(help="Print response in a streaming manner.")
    ] = False,
):
    ),
    streaming: bool = typer.Option(
        False,
        "--streaming/--no-streaming",
        help="Print the response in a streaming manner.",
    ),
) -> None:
    """Query a knowledge graph index."""
    from graphrag.cli.query import (
        run_basic_search,
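In the query command, --method and --query were previously required by virtue of having no Annotated default; the new style keeps them required by passing Ellipsis (...) as the typer.Option default. A minimal sketch of that pattern (plain str types here to stay self-contained, instead of graphrag's SearchMethod enum):

import typer

app = typer.Typer()


@app.command()
def query(
    method: str = typer.Option(..., "--method", "-m", help="The query algorithm to use."),
    query: str = typer.Option(..., "--query", "-q", help="The query to execute."),
) -> None:
    # Omitting either flag makes Typer exit with a "Missing option" error.
    typer.echo(f"{method}: {query}")


if __name__ == "__main__":
    app()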
@@ -33,7 +33,7 @@ models:
    concurrent_requests: {language_model_defaults.concurrent_requests} # max number of simultaneous LLM requests allowed
    async_mode: {language_model_defaults.async_mode.value} # or asyncio
    retry_strategy: native
    max_retries: -1 # set to -1 for dynamic retry logic (most optimal setting based on server response)
    max_retries: {language_model_defaults.max_retries}
    tokens_per_minute: {language_model_defaults.tokens_per_minute} # set to null to disable rate limiting
    requests_per_minute: {language_model_defaults.requests_per_minute} # set to null to disable rate limiting
  {defs.DEFAULT_EMBEDDING_MODEL_ID}:
@@ -51,7 +51,7 @@ models:
    concurrent_requests: {language_model_defaults.concurrent_requests} # max number of simultaneous LLM requests allowed
    async_mode: {language_model_defaults.async_mode.value} # or asyncio
    retry_strategy: native
    max_retries: -1 # set to -1 for dynamic retry logic (most optimal setting based on server response)
    max_retries: {language_model_defaults.max_retries}
    tokens_per_minute: {language_model_defaults.tokens_per_minute} # set to null to disable rate limiting
    requests_per_minute: {language_model_defaults.requests_per_minute} # set to null to disable rate limiting
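Practical consequence of the template change above: newly generated settings write a concrete default instead of max_retries: -1, and existing settings files that still carry -1 should be updated to a positive integer, since the sentinel no longer triggers dynamic behavior and fails the new validation. A standalone illustration of that rule (not graphrag's loader; pyyaml assumed available, values arbitrary):

import yaml  # pyyaml, assumed available

settings = yaml.safe_load(
    """
models:
  default_chat_model:
    retry_strategy: native
    max_retries: 5  # any positive integer; -1 is no longer accepted
"""
)

for name, model in settings["models"].items():
    retries = model.get("max_retries", 5)
    if not isinstance(retries, int) or retries < 1:
        msg = f"{name}: max_retries must be >= 1 (dynamic -1 is no longer supported)"
        raise ValueError(msg)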
@@ -198,10 +198,38 @@ class LanguageModelConfig(BaseModel):
        description="The number of tokens per minute to use for the LLM service.",
        default=language_model_defaults.tokens_per_minute,
    )

    def _validate_tokens_per_minute(self) -> None:
        """Validate the tokens per minute.

        Raises
        ------
        ValueError
            If the tokens per minute is less than 0.
        """
        # If the value is a number, check if it is less than 1
        if isinstance(self.tokens_per_minute, int) and self.tokens_per_minute < 1:
            msg = f"Tokens per minute must be a non zero positive number, 'auto' or null. Suggested value: {language_model_defaults.tokens_per_minute}."
            raise ValueError(msg)

    requests_per_minute: int | Literal["auto"] | None = Field(
        description="The number of requests per minute to use for the LLM service.",
        default=language_model_defaults.requests_per_minute,
    )

    def _validate_requests_per_minute(self) -> None:
        """Validate the requests per minute.

        Raises
        ------
        ValueError
            If the requests per minute is less than 0.
        """
        # If the value is a number, check if it is less than 1
        if isinstance(self.requests_per_minute, int) and self.requests_per_minute < 1:
            msg = f"Requests per minute must be a non zero positive number, 'auto' or null. Suggested value: {language_model_defaults.requests_per_minute}."
            raise ValueError(msg)

    retry_strategy: str = Field(
        description="The retry strategy to use for the LLM service.",
        default=language_model_defaults.retry_strategy,
@@ -210,6 +238,19 @@ class LanguageModelConfig(BaseModel):
        description="The maximum number of retries to use for the LLM service.",
        default=language_model_defaults.max_retries,
    )

    def _validate_max_retries(self) -> None:
        """Validate the maximum retries.

        Raises
        ------
        ValueError
            If the maximum retries is less than 0.
        """
        if self.max_retries < 1:
            msg = f"Maximum retries must be greater than or equal to 1. Suggested value: {language_model_defaults.max_retries}."
            raise ValueError(msg)

    max_retry_wait: float = Field(
        description="The maximum retry wait to use for the LLM service.",
        default=language_model_defaults.max_retry_wait,
@@ -279,6 +320,9 @@ class LanguageModelConfig(BaseModel):
        self._validate_type()
        self._validate_auth_type()
        self._validate_api_key()
        self._validate_tokens_per_minute()
        self._validate_requests_per_minute()
        self._validate_max_retries()
        self._validate_azure_settings()
        self._validate_encoding_model()
        return self
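The new validators follow the existing pattern in this config class: each field check is a private method raising ValueError, and all of them are chained from one post-construction validator. A minimal, self-contained pydantic sketch of that shape (assumed field names and defaults, not the real LanguageModelConfig):

from typing import Literal

from pydantic import BaseModel, Field, model_validator


class RetrySettings(BaseModel):
    max_retries: int = Field(default=10, description="Maximum LLM retries.")
    tokens_per_minute: int | Literal["auto"] | None = Field(default="auto")

    def _validate_max_retries(self) -> None:
        if self.max_retries < 1:
            msg = "max_retries must be >= 1."
            raise ValueError(msg)

    def _validate_tokens_per_minute(self) -> None:
        if isinstance(self.tokens_per_minute, int) and self.tokens_per_minute < 1:
            msg = "tokens_per_minute must be a positive integer, 'auto', or null."
            raise ValueError(msg)

    @model_validator(mode="after")
    def _validate_model(self) -> "RetrySettings":
        # Run every field-level check after the model is constructed.
        self._validate_max_retries()
        self._validate_tokens_per_minute()
        return self


RetrySettings(max_retries=3)      # passes
# RetrySettings(max_retries=-1)   # raises ValueError: the old dynamic sentinel is rejected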
@@ -109,10 +109,6 @@ async def _text_embed_with_vector_store(
    strategy_exec = load_strategy(strategy_type)
    strategy_config = {**strategy}

    # if max_retries is not set, inject a dynamically assigned value based on the total number of expected LLM calls to be made
    if strategy_config.get("llm") and strategy_config["llm"]["max_retries"] == -1:
        strategy_config["llm"]["max_retries"] = len(input)

    # Get vector-storage configuration
    insert_batch_size: int = (
        vector_store_config.get("batch_size") or DEFAULT_EMBEDDING_BATCH_SIZE
@@ -50,10 +50,6 @@ async def extract_covariates(
    strategy = strategy or {}
    strategy_config = {**strategy}

    # if max_retries is not set, inject a dynamically assigned value based on the total number of expected LLM calls to be made
    if strategy_config.get("llm") and strategy_config["llm"]["max_retries"] == -1:
        strategy_config["llm"]["max_retries"] = len(input)

    async def run_strategy(row):
        text = row[column]
        result = await run_extract_claims(
@@ -45,10 +45,6 @@ async def extract_graph(
    )
    strategy_config = {**strategy}

    # if max_retries is not set, inject a dynamically assigned value based on the total number of expected LLM calls to be made
    if strategy_config.get("llm") and strategy_config["llm"]["max_retries"] == -1:
        strategy_config["llm"]["max_retries"] = len(text_units)

    num_started = 0

    async def run_strategy(row):
@@ -45,10 +45,6 @@ async def summarize_communities(
    strategy_exec = load_strategy(strategy["type"])
    strategy_config = {**strategy}

    # if max_retries is not set, inject a dynamically assigned value based on the total number of expected LLM calls to be made
    if strategy_config.get("llm") and strategy_config["llm"]["max_retries"] == -1:
        strategy_config["llm"]["max_retries"] = len(nodes)

    community_hierarchy = (
        communities.explode("children")
        .rename({"children": "sub_community"}, axis=1)
@@ -36,10 +36,6 @@ async def summarize_descriptions(
    )
    strategy_config = {**strategy}

    # if max_retries is not set, inject a dynamically assigned value based on the maximum number of expected LLM calls to be made
    if strategy_config.get("llm") and strategy_config["llm"]["max_retries"] == -1:
        strategy_config["llm"]["max_retries"] = len(entities_df) + len(relationships_df)

    async def get_summarized(
        nodes: pd.DataFrame, edges: pd.DataFrame, semaphore: asyncio.Semaphore
    ):
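For context, the five workflow hunks above all delete the same injection: when an LLM strategy carried max_retries == -1, it was overwritten with the number of expected LLM calls (text units, nodes, entities plus relationships, and so on) before the strategy ran. With dynamic retries removed, the configured value passes through unchanged. A paraphrase of the deleted block with hypothetical names:

def resolve_max_retries(strategy_config: dict, expected_calls: int) -> dict:
    # Old behavior, now removed: -1 meant "size the retry budget to the workload".
    llm = strategy_config.get("llm")
    if llm and llm.get("max_retries") == -1:
        llm["max_retries"] = expected_calls
    return strategy_config


config = {"llm": {"max_retries": -1}}
print(resolve_max_retries(config, expected_calls=250))  # {'llm': {'max_retries': 250}}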
@@ -7,7 +7,6 @@ import asyncio
import sys

from graphrag.callbacks.noop_workflow_callbacks import NoopWorkflowCallbacks
from graphrag.config.defaults import language_model_defaults
from graphrag.config.models.graph_rag_config import GraphRagConfig
from graphrag.language_model.manager import ModelManager
from graphrag.logger.print_progress import ProgressLogger
@@ -18,9 +17,7 @@ def validate_config_names(logger: ProgressLogger, parameters: GraphRagConfig) ->
    # Validate Chat LLM configs
    # TODO: Replace default_chat_model with a way to select the model
    default_llm_settings = parameters.get_language_model_config("default_chat_model")
    # if max_retries is not set, set it to the default value
    if default_llm_settings.max_retries == -1:
        default_llm_settings.max_retries = language_model_defaults.max_retries

    llm = ModelManager().register_chat(
        name="test-llm",
        model_type=default_llm_settings.type,
@@ -40,8 +37,7 @@ def validate_config_names(logger: ProgressLogger, parameters: GraphRagConfig) ->
    embedding_llm_settings = parameters.get_language_model_config(
        parameters.embed_text.model_id
    )
    if embedding_llm_settings.max_retries == -1:
        embedding_llm_settings.max_retries = language_model_defaults.max_retries

    embed_llm = ModelManager().register_embedding(
        name="test-embed-llm",
        model_type=embedding_llm_settings.type,
@@ -52,11 +52,6 @@ def get_local_search_engine(
    """Create a local search engine based on data + configuration."""
    model_settings = config.get_language_model_config(config.local_search.chat_model_id)

    if model_settings.max_retries == -1:
        model_settings.max_retries = (
            len(reports) + len(entities) + len(relationships) + len(covariates)
        )

    chat_model = ModelManager().get_or_create_chat_model(
        name="local_search_chat",
        model_type=model_settings.type,
@@ -66,10 +61,7 @@ def get_local_search_engine(
    embedding_settings = config.get_language_model_config(
        config.local_search.embedding_model_id
    )
    if embedding_settings.max_retries == -1:
        embedding_settings.max_retries = (
            len(reports) + len(entities) + len(relationships)
        )

    embedding_model = ModelManager().get_or_create_embedding_model(
        name="local_search_embedding",
        model_type=embedding_settings.type,
@@ -134,8 +126,6 @@ def get_global_search_engine(
        config.global_search.chat_model_id
    )

    if model_settings.max_retries == -1:
        model_settings.max_retries = len(reports) + len(entities)
    model = ModelManager().get_or_create_chat_model(
        name="global_search",
        model_type=model_settings.type,
@@ -220,13 +210,6 @@ def get_drift_search_engine(
        config.drift_search.chat_model_id
    )

    if chat_model_settings.max_retries == -1:
        chat_model_settings.max_retries = (
            config.drift_search.drift_k_followups
            * config.drift_search.primer_folds
            * config.drift_search.n_depth
        )

    chat_model = ModelManager().get_or_create_chat_model(
        name="drift_search_chat",
        model_type=chat_model_settings.type,
@@ -237,11 +220,6 @@ def get_drift_search_engine(
        config.drift_search.embedding_model_id
    )

    if embedding_model_settings.max_retries == -1:
        embedding_model_settings.max_retries = (
            len(reports) + len(entities) + len(relationships)
        )

    embedding_model = ModelManager().get_or_create_embedding_model(
        name="drift_search_embedding",
        model_type=embedding_model_settings.type,
@@ -283,9 +261,6 @@ def get_basic_search_engine(
        config.basic_search.chat_model_id
    )

    if chat_model_settings.max_retries == -1:
        chat_model_settings.max_retries = len(text_units)

    chat_model = ModelManager().get_or_create_chat_model(
        name="basic_search_chat",
        model_type=chat_model_settings.type,
@@ -295,8 +270,6 @@ def get_basic_search_engine(
    embedding_model_settings = config.get_language_model_config(
        config.basic_search.embedding_model_id
    )
    if embedding_model_settings.max_retries == -1:
        embedding_model_settings.max_retries = len(text_units)

    embedding_model = ModelManager().get_or_create_embedding_model(
        name="basic_search_embedding",
@@ -245,6 +245,7 @@ ignore = [
    # TODO RE-Enable when we get bandwidth
    "PERF203", # Needs restructuring of errors, we should bail-out on first error
    "C901", # needs refactoring to remove cyclomatic complexity
    "B008", # Needs to restructure our cli params with Typer into constants
]

[tool.ruff.lint.per-file-ignores]
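The new "B008" ignore is needed because that flake8-bugbear rule (surfaced through Ruff) flags function calls used as argument defaults, since such calls run once at definition time. The typer.Option style adopted in this PR triggers it by design; a minimal illustration with a hypothetical command:

import typer


def command(root: str = typer.Option(".", "--root")) -> None:  # B008 would flag this default
    ...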