Task/remove dynamic retries (#1941)

* Remove max retries. Update Typer args

* Format

* Semver

* Fix typo

* Ruff and Typos

* Format
Alonso Guevara 2025-05-20 11:48:27 -06:00 committed by GitHub
parent 36948b8d2e
commit 24018c6155
13 changed files with 393 additions and 335 deletions

View File

@@ -0,0 +1,4 @@
{
"type": "minor",
"description": "Remove Dynamic Max Retries support. Refactor typer typing in cli interface"
}

View File

@@ -17,7 +17,7 @@ import annotated_types
from pydantic import PositiveInt, validate_call
from graphrag.callbacks.noop_workflow_callbacks import NoopWorkflowCallbacks
from graphrag.config.defaults import graphrag_config_defaults, language_model_defaults
from graphrag.config.defaults import graphrag_config_defaults
from graphrag.config.models.graph_rag_config import GraphRagConfig
from graphrag.language_model.manager import ModelManager
from graphrag.logger.base import ProgressLogger
@@ -109,15 +109,6 @@ async def generate_indexing_prompts(
logger.info("Retrieving language model configuration...")
default_llm_settings = config.get_language_model_config(PROMPT_TUNING_MODEL_ID)
# if max_retries is not set, inject a dynamically assigned value based on the number of expected LLM calls
# to be made or fallback to a default value in the worst case
if default_llm_settings.max_retries < -1:
default_llm_settings.max_retries = min(
len(doc_list), language_model_defaults.max_retries
)
msg = f"max_retries not set, using default value: {default_llm_settings.max_retries}"
logger.warning(msg)
logger.info("Creating language model...")
llm = ModelManager().register_chat(
name="prompt_tuning",

View File

@@ -7,7 +7,6 @@ import os
import re
from collections.abc import Callable
from pathlib import Path
from typing import Annotated
import typer
@@ -78,25 +77,40 @@ def path_autocomplete(
return completer
CONFIG_AUTOCOMPLETE = path_autocomplete(
file_okay=True,
dir_okay=False,
match_wildcard="*.yaml",
readable=True,
)
ROOT_AUTOCOMPLETE = path_autocomplete(
file_okay=False,
dir_okay=True,
writable=True,
match_wildcard="*",
)
@app.command("init")
def _initialize_cli(
root: Annotated[
Path,
typer.Option(
help="The project root directory.",
dir_okay=True,
writable=True,
resolve_path=True,
autocompletion=path_autocomplete(
file_okay=False, dir_okay=True, writable=True, match_wildcard="*"
),
),
],
force: Annotated[
bool,
typer.Option(help="Force initialization even if the project already exists."),
] = False,
):
root: Path = typer.Option(
Path(),
"--root",
"-r",
help="The project root directory.",
dir_okay=True,
writable=True,
resolve_path=True,
autocompletion=ROOT_AUTOCOMPLETE,
),
force: bool = typer.Option(
False,
"--force",
"-f",
help="Force initialization even if the project already exists.",
),
) -> None:
"""Generate a default configuration file."""
from graphrag.cli.initialize import initialize_project_at
@@ -105,60 +119,80 @@ def _initialize_cli(
@app.command("index")
def _index_cli(
config: Annotated[
Path | None,
typer.Option(
help="The configuration to use.", exists=True, file_okay=True, readable=True
config: Path | None = typer.Option(
None,
"--config",
"-c",
help="The configuration to use.",
exists=True,
file_okay=True,
readable=True,
autocompletion=CONFIG_AUTOCOMPLETE,
),
root: Path = typer.Option(
Path(),
"--root",
"-r",
help="The project root directory.",
exists=True,
dir_okay=True,
writable=True,
resolve_path=True,
autocompletion=ROOT_AUTOCOMPLETE,
),
method: IndexingMethod = typer.Option(
IndexingMethod.Standard.value,
"--method",
"-m",
help="The indexing method to use.",
),
verbose: bool = typer.Option(
False,
"--verbose",
"-v",
help="Run the indexing pipeline with verbose logging",
),
memprofile: bool = typer.Option(
False,
"--memprofile",
help="Run the indexing pipeline with memory profiling",
),
logger: LoggerType = typer.Option(
LoggerType.RICH.value,
"--logger",
help="The progress logger to use.",
),
dry_run: bool = typer.Option(
False,
"--dry-run",
help=(
"Run the indexing pipeline without executing any steps "
"to inspect and validate the configuration."
),
] = None,
root: Annotated[
Path,
typer.Option(
help="The project root directory.",
exists=True,
dir_okay=True,
writable=True,
resolve_path=True,
autocompletion=path_autocomplete(
file_okay=False, dir_okay=True, writable=True, match_wildcard="*"
),
),
cache: bool = typer.Option(
True,
"--cache/--no-cache",
help="Use LLM cache.",
),
skip_validation: bool = typer.Option(
False,
"--skip-validation",
help="Skip any preflight validation. Useful when running no LLM steps.",
),
output: Path | None = typer.Option(
None,
"--output",
"-o",
help=(
"Indexing pipeline output directory. "
"Overrides output.base_dir in the configuration file."
),
] = Path(), # set default to current directory
method: Annotated[
IndexingMethod, typer.Option(help="The indexing method to use.")
] = IndexingMethod.Standard,
verbose: Annotated[
bool, typer.Option(help="Run the indexing pipeline with verbose logging")
] = False,
memprofile: Annotated[
bool, typer.Option(help="Run the indexing pipeline with memory profiling")
] = False,
logger: Annotated[
LoggerType, typer.Option(help="The progress logger to use.")
] = LoggerType.RICH,
dry_run: Annotated[
bool,
typer.Option(
help="Run the indexing pipeline without executing any steps to inspect and validate the configuration."
),
] = False,
cache: Annotated[bool, typer.Option(help="Use LLM cache.")] = True,
skip_validation: Annotated[
bool,
typer.Option(
help="Skip any preflight validation. Useful when running no LLM steps."
),
] = False,
output: Annotated[
Path | None,
typer.Option(
help="Indexing pipeline output directory. Overrides output.base_dir in the configuration file.",
dir_okay=True,
writable=True,
resolve_path=True,
),
] = None,
):
dir_okay=True,
writable=True,
resolve_path=True,
),
) -> None:
"""Build a knowledge graph index."""
from graphrag.cli.index import index_cli
@@ -178,51 +212,72 @@ def _index_cli(
@app.command("update")
def _update_cli(
config: Annotated[
Path | None,
typer.Option(
help="The configuration to use.", exists=True, file_okay=True, readable=True
config: Path | None = typer.Option(
None,
"--config",
"-c",
help="The configuration to use.",
exists=True,
file_okay=True,
readable=True,
autocompletion=CONFIG_AUTOCOMPLETE,
),
root: Path = typer.Option(
Path(),
"--root",
"-r",
help="The project root directory.",
exists=True,
dir_okay=True,
writable=True,
resolve_path=True,
autocompletion=ROOT_AUTOCOMPLETE,
),
method: IndexingMethod = typer.Option(
IndexingMethod.Standard.value,
"--method",
"-m",
help="The indexing method to use.",
),
verbose: bool = typer.Option(
False,
"--verbose",
"-v",
help="Run the indexing pipeline with verbose logging.",
),
memprofile: bool = typer.Option(
False,
"--memprofile",
help="Run the indexing pipeline with memory profiling.",
),
logger: LoggerType = typer.Option(
LoggerType.RICH.value,
"--logger",
help="The progress logger to use.",
),
cache: bool = typer.Option(
True,
"--cache/--no-cache",
help="Use LLM cache.",
),
skip_validation: bool = typer.Option(
False,
"--skip-validation",
help="Skip any preflight validation. Useful when running no LLM steps.",
),
output: Path | None = typer.Option(
None,
"--output",
"-o",
help=(
"Indexing pipeline output directory. "
"Overrides output.base_dir in the configuration file."
),
] = None,
root: Annotated[
Path,
typer.Option(
help="The project root directory.",
exists=True,
dir_okay=True,
writable=True,
resolve_path=True,
),
] = Path(), # set default to current directory
method: Annotated[
IndexingMethod, typer.Option(help="The indexing method to use.")
] = IndexingMethod.Standard,
verbose: Annotated[
bool, typer.Option(help="Run the indexing pipeline with verbose logging")
] = False,
memprofile: Annotated[
bool, typer.Option(help="Run the indexing pipeline with memory profiling")
] = False,
logger: Annotated[
LoggerType, typer.Option(help="The progress logger to use.")
] = LoggerType.RICH,
cache: Annotated[bool, typer.Option(help="Use LLM cache.")] = True,
skip_validation: Annotated[
bool,
typer.Option(
help="Skip any preflight validation. Useful when running no LLM steps."
),
] = False,
output: Annotated[
Path | None,
typer.Option(
help="Indexing pipeline output directory. Overrides output.base_dir in the configuration file.",
dir_okay=True,
writable=True,
resolve_path=True,
),
] = None,
):
dir_okay=True,
writable=True,
resolve_path=True,
),
) -> None:
"""
Update an existing knowledge graph index.
@@ -245,104 +300,107 @@ def _update_cli(
@app.command("prompt-tune")
def _prompt_tune_cli(
root: Annotated[
Path,
typer.Option(
help="The project root directory.",
exists=True,
dir_okay=True,
writable=True,
resolve_path=True,
autocompletion=path_autocomplete(
file_okay=False, dir_okay=True, writable=True, match_wildcard="*"
),
root: Path = typer.Option(
Path(),
"--root",
"-r",
help="The project root directory.",
exists=True,
dir_okay=True,
writable=True,
resolve_path=True,
autocompletion=ROOT_AUTOCOMPLETE,
),
config: Path | None = typer.Option(
None,
"--config",
"-c",
help="The configuration to use.",
exists=True,
file_okay=True,
readable=True,
autocompletion=CONFIG_AUTOCOMPLETE,
),
verbose: bool = typer.Option(
False,
"--verbose",
"-v",
help="Run the prompt tuning pipeline with verbose logging.",
),
logger: LoggerType = typer.Option(
LoggerType.RICH.value,
"--logger",
help="The progress logger to use.",
),
domain: str | None = typer.Option(
None,
"--domain",
help=(
"The domain your input data is related to. "
"For example 'space science', 'microbiology', 'environmental news'. "
"If not defined, a domain will be inferred from the input data."
),
] = Path(), # set default to current directory
config: Annotated[
Path | None,
typer.Option(
help="The configuration to use.",
exists=True,
file_okay=True,
readable=True,
autocompletion=path_autocomplete(
file_okay=True, dir_okay=False, match_wildcard="*"
),
),
] = None,
verbose: Annotated[
bool, typer.Option(help="Run the prompt tuning pipeline with verbose logging")
] = False,
logger: Annotated[
LoggerType, typer.Option(help="The progress logger to use.")
] = LoggerType.RICH,
domain: Annotated[
str | None,
typer.Option(
help="The domain your input data is related to. For example 'space science', 'microbiology', 'environmental news'. If not defined, a domain will be inferred from the input data."
),
] = None,
selection_method: Annotated[
DocSelectionType, typer.Option(help="The text chunk selection method.")
] = DocSelectionType.RANDOM,
n_subset_max: Annotated[
int,
typer.Option(
help="The number of text chunks to embed when --selection-method=auto."
),
] = N_SUBSET_MAX,
k: Annotated[
int,
typer.Option(
help="The maximum number of documents to select from each centroid when --selection-method=auto."
),
] = K,
limit: Annotated[
int,
typer.Option(
help="The number of documents to load when --selection-method={random,top}."
),
] = LIMIT,
max_tokens: Annotated[
int, typer.Option(help="The max token count for prompt generation.")
] = MAX_TOKEN_COUNT,
min_examples_required: Annotated[
int,
typer.Option(
help="The minimum number of examples to generate/include in the entity extraction prompt."
),
] = 2,
chunk_size: Annotated[
int,
typer.Option(
help="The size of each example text chunk. Overrides chunks.size in the configuration file."
),
] = graphrag_config_defaults.chunks.size,
overlap: Annotated[
int,
typer.Option(
help="The overlap size for chunking documents. Overrides chunks.overlap in the configuration file"
),
] = graphrag_config_defaults.chunks.overlap,
language: Annotated[
str | None,
typer.Option(
help="The primary language used for inputs and outputs in graphrag prompts."
),
] = None,
discover_entity_types: Annotated[
bool, typer.Option(help="Discover and extract unspecified entity types.")
] = True,
output: Annotated[
Path,
typer.Option(
help="The directory to save prompts to, relative to the project root directory.",
dir_okay=True,
writable=True,
resolve_path=True,
),
] = Path("prompts"),
):
),
selection_method: DocSelectionType = typer.Option(
DocSelectionType.RANDOM.value,
"--selection-method",
help="The text chunk selection method.",
),
n_subset_max: int = typer.Option(
N_SUBSET_MAX,
"--n-subset-max",
help="The number of text chunks to embed when --selection-method=auto.",
),
k: int = typer.Option(
K,
"--k",
help="The maximum number of documents to select from each centroid when --selection-method=auto.",
),
limit: int = typer.Option(
LIMIT,
"--limit",
help="The number of documents to load when --selection-method={random,top}.",
),
max_tokens: int = typer.Option(
MAX_TOKEN_COUNT,
"--max-tokens",
help="The max token count for prompt generation.",
),
min_examples_required: int = typer.Option(
2,
"--min-examples-required",
help="The minimum number of examples to generate/include in the entity extraction prompt.",
),
chunk_size: int = typer.Option(
graphrag_config_defaults.chunks.size,
"--chunk-size",
help="The size of each example text chunk. Overrides chunks.size in the configuration file.",
),
overlap: int = typer.Option(
graphrag_config_defaults.chunks.overlap,
"--overlap",
help="The overlap size for chunking documents. Overrides chunks.overlap in the configuration file.",
),
language: str | None = typer.Option(
None,
"--language",
help="The primary language used for inputs and outputs in graphrag prompts.",
),
discover_entity_types: bool = typer.Option(
True,
"--discover-entity-types/--no-discover-entity-types",
help="Discover and extract unspecified entity types.",
),
output: Path = typer.Option(
Path("prompts"),
"--output",
"-o",
help="The directory to save prompts to, relative to the project root directory.",
dir_okay=True,
writable=True,
resolve_path=True,
),
) -> None:
"""Generate custom graphrag prompts with your own data (i.e. auto templating)."""
import asyncio
@@ -373,66 +431,77 @@ def _prompt_tune_cli(
@app.command("query")
def _query_cli(
method: Annotated[SearchMethod, typer.Option(help="The query algorithm to use.")],
query: Annotated[str, typer.Option(help="The query to execute.")],
config: Annotated[
Path | None,
typer.Option(
help="The configuration to use.",
exists=True,
file_okay=True,
readable=True,
autocompletion=path_autocomplete(
file_okay=True, dir_okay=False, match_wildcard="*"
),
method: SearchMethod = typer.Option(
...,
"--method",
"-m",
help="The query algorithm to use.",
),
query: str = typer.Option(
...,
"--query",
"-q",
help="The query to execute.",
),
config: Path | None = typer.Option(
None,
"--config",
"-c",
help="The configuration to use.",
exists=True,
file_okay=True,
readable=True,
autocompletion=CONFIG_AUTOCOMPLETE,
),
data: Path | None = typer.Option(
None,
"--data",
"-d",
help="Index output directory (contains the parquet files).",
exists=True,
dir_okay=True,
readable=True,
resolve_path=True,
autocompletion=ROOT_AUTOCOMPLETE,
),
root: Path = typer.Option(
Path(),
"--root",
"-r",
help="The project root directory.",
exists=True,
dir_okay=True,
writable=True,
resolve_path=True,
autocompletion=ROOT_AUTOCOMPLETE,
),
community_level: int = typer.Option(
2,
"--community-level",
help=(
"Leiden hierarchy level from which to load community reports. "
"Higher values represent smaller communities."
),
] = None,
data: Annotated[
Path | None,
typer.Option(
help="Indexing pipeline output directory (i.e. contains the parquet files).",
exists=True,
dir_okay=True,
readable=True,
resolve_path=True,
autocompletion=path_autocomplete(
file_okay=False, dir_okay=True, match_wildcard="*"
),
),
dynamic_community_selection: bool = typer.Option(
False,
"--dynamic-community-selection/--no-dynamic-selection",
help="Use global search with dynamic community selection.",
),
response_type: str = typer.Option(
"Multiple Paragraphs",
"--response-type",
help=(
"Free-form description of the desired response format "
"(e.g. 'Single Sentence', 'List of 3-7 Points', etc.)."
),
] = None,
root: Annotated[
Path,
typer.Option(
help="The project root directory.",
exists=True,
dir_okay=True,
writable=True,
resolve_path=True,
autocompletion=path_autocomplete(
file_okay=False, dir_okay=True, match_wildcard="*"
),
),
] = Path(), # set default to current directory
community_level: Annotated[
int,
typer.Option(
help="The community level in the Leiden community hierarchy from which to load community reports. Higher values represent reports from smaller communities."
),
] = 2,
dynamic_community_selection: Annotated[
bool,
typer.Option(help="Use global search with dynamic community selection."),
] = False,
response_type: Annotated[
str,
typer.Option(
help="Free form text describing the response type and format, can be anything, e.g. Multiple Paragraphs, Single Paragraph, Single Sentence, List of 3-7 Points, Single Page, Multi-Page Report. Default: Multiple Paragraphs"
),
] = "Multiple Paragraphs",
streaming: Annotated[
bool, typer.Option(help="Print response in a streaming manner.")
] = False,
):
),
streaming: bool = typer.Option(
False,
"--streaming/--no-streaming",
help="Print the response in a streaming manner.",
),
) -> None:
"""Query a knowledge graph index."""
from graphrag.cli.query import (
run_basic_search,
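
Across all of the commands above, the refactor replaces Annotated[...] parameter declarations with plain typer.Option(...) defaults that spell out explicit long/short flag names. A minimal, self-contained sketch of the adopted style follows; the "greet" command and its flags are hypothetical, only the pattern matters.

# Hedged sketch of the typer.Option declaration style used in the refactor.
from pathlib import Path

import typer

app = typer.Typer()

@app.command("greet")
def _greet_cli(
    root: Path = typer.Option(
        Path(),            # default value comes first
        "--root",
        "-r",              # explicit long and short flag names
        help="The project root directory.",
        dir_okay=True,
        resolve_path=True,
    ),
    loud: bool = typer.Option(
        False,
        "--loud/--quiet",  # paired boolean flags
        help="Print the greeting in upper case.",
    ),
) -> None:
    """Print a greeting rooted at the given directory."""
    message = f"hello from {root}"
    typer.echo(message.upper() if loud else message)

if __name__ == "__main__":
    app()

Note how reusable autocompletion callbacks (CONFIG_AUTOCOMPLETE, ROOT_AUTOCOMPLETE above) are hoisted to module-level constants so they are built once rather than re-declared per command.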

View File

@@ -33,7 +33,7 @@ models:
concurrent_requests: {language_model_defaults.concurrent_requests} # max number of simultaneous LLM requests allowed
async_mode: {language_model_defaults.async_mode.value} # or asyncio
retry_strategy: native
max_retries: -1 # set to -1 for dynamic retry logic (most optimal setting based on server response)
max_retries: {language_model_defaults.max_retries}
tokens_per_minute: {language_model_defaults.tokens_per_minute} # set to null to disable rate limiting
requests_per_minute: {language_model_defaults.requests_per_minute} # set to null to disable rate limiting
{defs.DEFAULT_EMBEDDING_MODEL_ID}:
@@ -51,7 +51,7 @@ models:
concurrent_requests: {language_model_defaults.concurrent_requests} # max number of simultaneous LLM requests allowed
async_mode: {language_model_defaults.async_mode.value} # or asyncio
retry_strategy: native
max_retries: -1 # set to -1 for dynamic retry logic (most optimal setting based on server response)
max_retries: {language_model_defaults.max_retries}
tokens_per_minute: {language_model_defaults.tokens_per_minute} # set to null to disable rate limiting
requests_per_minute: {language_model_defaults.requests_per_minute} # set to null to disable rate limiting
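
This settings template is a Python f-string, so the new lines simply interpolate the library default for max_retries instead of hard-coding the -1 sentinel and its "dynamic retry logic" comment. A tiny hedged sketch of how such a template renders; the dataclass and its values are illustrative stand-ins, not graphrag's real defaults.

# Simplified stand-in for rendering a settings template from defaults.
from dataclasses import dataclass

@dataclass
class LanguageModelDefaults:
    max_retries: int = 10
    tokens_per_minute: str = "auto"
    requests_per_minute: str = "auto"

language_model_defaults = LanguageModelDefaults()

INIT_YAML = f"""\
    retry_strategy: native
    max_retries: {language_model_defaults.max_retries}
    tokens_per_minute: {language_model_defaults.tokens_per_minute}  # set to null to disable rate limiting
    requests_per_minute: {language_model_defaults.requests_per_minute}  # set to null to disable rate limiting
"""

print(INIT_YAML)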

View File

@@ -198,10 +198,38 @@ class LanguageModelConfig(BaseModel):
description="The number of tokens per minute to use for the LLM service.",
default=language_model_defaults.tokens_per_minute,
)
def _validate_tokens_per_minute(self) -> None:
"""Validate the tokens per minute.
Raises
------
ValueError
If the tokens per minute is less than 0.
"""
# If the value is a number, check if it is less than 1
if isinstance(self.tokens_per_minute, int) and self.tokens_per_minute < 1:
msg = f"Tokens per minute must be a non zero positive number, 'auto' or null. Suggested value: {language_model_defaults.tokens_per_minute}."
raise ValueError(msg)
requests_per_minute: int | Literal["auto"] | None = Field(
description="The number of requests per minute to use for the LLM service.",
default=language_model_defaults.requests_per_minute,
)
def _validate_requests_per_minute(self) -> None:
"""Validate the requests per minute.
Raises
------
ValueError
If the requests per minute is less than 0.
"""
# If the value is a number, check if it is less than 1
if isinstance(self.requests_per_minute, int) and self.requests_per_minute < 1:
msg = f"Requests per minute must be a non zero positive number, 'auto' or null. Suggested value: {language_model_defaults.requests_per_minute}."
raise ValueError(msg)
retry_strategy: str = Field(
description="The retry strategy to use for the LLM service.",
default=language_model_defaults.retry_strategy,
@@ -210,6 +238,19 @@ class LanguageModelConfig(BaseModel):
description="The maximum number of retries to use for the LLM service.",
default=language_model_defaults.max_retries,
)
def _validate_max_retries(self) -> None:
"""Validate the maximum retries.
Raises
------
ValueError
If the maximum retries is less than 0.
"""
if self.max_retries < 1:
msg = f"Maximum retries must be greater than or equal to 1. Suggested value: {language_model_defaults.max_retries}."
raise ValueError(msg)
max_retry_wait: float = Field(
description="The maximum retry wait to use for the LLM service.",
default=language_model_defaults.max_retry_wait,
@@ -279,6 +320,9 @@ class LanguageModelConfig(BaseModel):
self._validate_type()
self._validate_auth_type()
self._validate_api_key()
self._validate_tokens_per_minute()
self._validate_requests_per_minute()
self._validate_max_retries()
self._validate_azure_settings()
self._validate_encoding_model()
return self
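
The three new _validate_* helpers follow the same pattern: a private method that raises ValueError for out-of-range values, invoked from the model's final validation hook. Below is a self-contained pydantic v2 sketch of that pattern; only the field names mirror the diff, the class and defaults are simplified stand-ins.

# Hedged sketch of the validation pattern added above.
from typing import Literal

from pydantic import BaseModel, Field, model_validator

class RetrySettings(BaseModel):
    max_retries: int = Field(default=10, description="Maximum retries for the LLM service.")
    tokens_per_minute: int | Literal["auto"] | None = Field(default="auto")
    requests_per_minute: int | Literal["auto"] | None = Field(default="auto")

    def _validate_max_retries(self) -> None:
        if self.max_retries < 1:
            msg = "Maximum retries must be greater than or equal to 1."
            raise ValueError(msg)

    def _validate_tokens_per_minute(self) -> None:
        # Only numeric values are range-checked; "auto" and None pass through.
        if isinstance(self.tokens_per_minute, int) and self.tokens_per_minute < 1:
            msg = "Tokens per minute must be a non zero positive number, 'auto' or null."
            raise ValueError(msg)

    def _validate_requests_per_minute(self) -> None:
        if isinstance(self.requests_per_minute, int) and self.requests_per_minute < 1:
            msg = "Requests per minute must be a non zero positive number, 'auto' or null."
            raise ValueError(msg)

    @model_validator(mode="after")
    def _validate_model(self) -> "RetrySettings":
        self._validate_tokens_per_minute()
        self._validate_requests_per_minute()
        self._validate_max_retries()
        return self

RetrySettings(max_retries=5)    # fine
# RetrySettings(max_retries=0)  # would raise ValueError at construction time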

View File

@@ -109,10 +109,6 @@ async def _text_embed_with_vector_store(
strategy_exec = load_strategy(strategy_type)
strategy_config = {**strategy}
# if max_retries is not set, inject a dynamically assigned value based on the total number of expected LLM calls to be made
if strategy_config.get("llm") and strategy_config["llm"]["max_retries"] == -1:
strategy_config["llm"]["max_retries"] = len(input)
# Get vector-storage configuration
insert_batch_size: int = (
vector_store_config.get("batch_size") or DEFAULT_EMBEDDING_BATCH_SIZE

View File

@@ -50,10 +50,6 @@ async def extract_covariates(
strategy = strategy or {}
strategy_config = {**strategy}
# if max_retries is not set, inject a dynamically assigned value based on the total number of expected LLM calls to be made
if strategy_config.get("llm") and strategy_config["llm"]["max_retries"] == -1:
strategy_config["llm"]["max_retries"] = len(input)
async def run_strategy(row):
text = row[column]
result = await run_extract_claims(

View File

@@ -45,10 +45,6 @@ async def extract_graph(
)
strategy_config = {**strategy}
# if max_retries is not set, inject a dynamically assigned value based on the total number of expected LLM calls to be made
if strategy_config.get("llm") and strategy_config["llm"]["max_retries"] == -1:
strategy_config["llm"]["max_retries"] = len(text_units)
num_started = 0
async def run_strategy(row):

View File

@@ -45,10 +45,6 @@ async def summarize_communities(
strategy_exec = load_strategy(strategy["type"])
strategy_config = {**strategy}
# if max_retries is not set, inject a dynamically assigned value based on the total number of expected LLM calls to be made
if strategy_config.get("llm") and strategy_config["llm"]["max_retries"] == -1:
strategy_config["llm"]["max_retries"] = len(nodes)
community_hierarchy = (
communities.explode("children")
.rename({"children": "sub_community"}, axis=1)

View File

@@ -36,10 +36,6 @@ async def summarize_descriptions(
)
strategy_config = {**strategy}
# if max_retries is not set, inject a dynamically assigned value based on the maximum number of expected LLM calls to be made
if strategy_config.get("llm") and strategy_config["llm"]["max_retries"] == -1:
strategy_config["llm"]["max_retries"] = len(entities_df) + len(relationships_df)
async def get_summarized(
nodes: pd.DataFrame, edges: pd.DataFrame, semaphore: asyncio.Semaphore
):
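
The five workflow hunks above all delete the same injection pattern: before running a strategy, the code checked whether the strategy's LLM config still carried the -1 sentinel and, if so, overwrote max_retries with the number of rows about to be processed. A hedged reconstruction of that removed pattern, isolated from any graphrag types (the strategy dict and counts are made up):

# Illustrative reconstruction of the pattern deleted in the workflows above.
def apply_dynamic_retries(strategy: dict, expected_calls: int) -> dict:
    """Old behavior: replace the -1 sentinel with the number of expected LLM calls."""
    config = {**strategy}
    if config.get("llm") and config["llm"]["max_retries"] == -1:
        config["llm"]["max_retries"] = expected_calls
    return config

strategy = {"type": "graph_intelligence", "llm": {"max_retries": -1}}
print(apply_dynamic_retries(strategy, expected_calls=42)["llm"]["max_retries"])  # 42

# New behavior: the strategy config is used as-is, so max_retries is whatever
# the (now strictly validated) language model configuration specifies.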

View File

@@ -7,7 +7,6 @@ import asyncio
import sys
from graphrag.callbacks.noop_workflow_callbacks import NoopWorkflowCallbacks
from graphrag.config.defaults import language_model_defaults
from graphrag.config.models.graph_rag_config import GraphRagConfig
from graphrag.language_model.manager import ModelManager
from graphrag.logger.print_progress import ProgressLogger
@@ -18,9 +17,7 @@ def validate_config_names(logger: ProgressLogger, parameters: GraphRagConfig) ->
# Validate Chat LLM configs
# TODO: Replace default_chat_model with a way to select the model
default_llm_settings = parameters.get_language_model_config("default_chat_model")
# if max_retries is not set, set it to the default value
if default_llm_settings.max_retries == -1:
default_llm_settings.max_retries = language_model_defaults.max_retries
llm = ModelManager().register_chat(
name="test-llm",
model_type=default_llm_settings.type,
@@ -40,8 +37,7 @@ def validate_config_names(logger: ProgressLogger, parameters: GraphRagConfig) ->
embedding_llm_settings = parameters.get_language_model_config(
parameters.embed_text.model_id
)
if embedding_llm_settings.max_retries == -1:
embedding_llm_settings.max_retries = language_model_defaults.max_retries
embed_llm = ModelManager().register_embedding(
name="test-embed-llm",
model_type=embedding_llm_settings.type,
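
With the sentinel gone, config validation no longer patches max_retries before its health-check calls; the chat and embedding models are registered straight from the resolved settings. A minimal hedged sketch of that simplified flow, where register_chat_model is a plain stand-in rather than graphrag's actual ModelManager API:

# Hedged sketch: validation now trusts the configured max_retries as-is.
def register_chat_model(name: str, model_type: str, max_retries: int) -> dict:
    # In the old flow, a -1 sentinel here would first be rewritten to the
    # library default; now the value arrives already validated (>= 1).
    return {"name": name, "type": model_type, "max_retries": max_retries}

settings = {"type": "openai_chat", "max_retries": 10}  # illustrative values
llm = register_chat_model("test-llm", settings["type"], settings["max_retries"])
print(llm)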

View File

@@ -52,11 +52,6 @@ def get_local_search_engine(
"""Create a local search engine based on data + configuration."""
model_settings = config.get_language_model_config(config.local_search.chat_model_id)
if model_settings.max_retries == -1:
model_settings.max_retries = (
len(reports) + len(entities) + len(relationships) + len(covariates)
)
chat_model = ModelManager().get_or_create_chat_model(
name="local_search_chat",
model_type=model_settings.type,
@@ -66,10 +61,7 @@ def get_local_search_engine(
embedding_settings = config.get_language_model_config(
config.local_search.embedding_model_id
)
if embedding_settings.max_retries == -1:
embedding_settings.max_retries = (
len(reports) + len(entities) + len(relationships)
)
embedding_model = ModelManager().get_or_create_embedding_model(
name="local_search_embedding",
model_type=embedding_settings.type,
@@ -134,8 +126,6 @@ def get_global_search_engine(
config.global_search.chat_model_id
)
if model_settings.max_retries == -1:
model_settings.max_retries = len(reports) + len(entities)
model = ModelManager().get_or_create_chat_model(
name="global_search",
model_type=model_settings.type,
@@ -220,13 +210,6 @@ def get_drift_search_engine(
config.drift_search.chat_model_id
)
if chat_model_settings.max_retries == -1:
chat_model_settings.max_retries = (
config.drift_search.drift_k_followups
* config.drift_search.primer_folds
* config.drift_search.n_depth
)
chat_model = ModelManager().get_or_create_chat_model(
name="drift_search_chat",
model_type=chat_model_settings.type,
@@ -237,11 +220,6 @@ def get_drift_search_engine(
config.drift_search.embedding_model_id
)
if embedding_model_settings.max_retries == -1:
embedding_model_settings.max_retries = (
len(reports) + len(entities) + len(relationships)
)
embedding_model = ModelManager().get_or_create_embedding_model(
name="drift_search_embedding",
model_type=embedding_model_settings.type,
@@ -283,9 +261,6 @@ def get_basic_search_engine(
config.basic_search.chat_model_id
)
if chat_model_settings.max_retries == -1:
chat_model_settings.max_retries = len(text_units)
chat_model = ModelManager().get_or_create_chat_model(
name="basic_search_chat",
model_type=chat_model_settings.type,
@@ -295,8 +270,6 @@ def get_basic_search_engine(
embedding_model_settings = config.get_language_model_config(
config.basic_search.embedding_model_id
)
if embedding_model_settings.max_retries == -1:
embedding_model_settings.max_retries = len(text_units)
embedding_model = ModelManager().get_or_create_embedding_model(
name="basic_search_embedding",

View File

@@ -245,6 +245,7 @@ ignore = [
# TODO RE-Enable when we get bandwidth
"PERF203", # Needs restructuring of errors, we should bail-out on first error
"C901", # needs refactoring to remove cyclomatic complexity
"B008", # Needs to restructure our cli params with Typer into constants
]
[tool.ruff.lint.per-file-ignores]
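
The new B008 suppression exists because Ruff flags function calls in argument defaults, which is exactly what the param = typer.Option(...) style above does on every CLI parameter; the accompanying TODO notes that the eventual fix is hoisting those calls into module-level constants, as was already started with CONFIG_AUTOCOMPLETE and ROOT_AUTOCOMPLETE. A tiny hedged sketch of what B008 objects to and the constant-based workaround; the command and flags below are hypothetical, only the lint pattern matters.

# What ruff's B008 flags, and the constant-extraction workaround the TODO refers to.
import typer

app = typer.Typer()

# B008: function call (typer.Option) in an argument default. The call runs
# once at import time, which is how Typer expects to be used, but Ruff flags
# it as a general Python pitfall unless the rule is ignored or the call is
# listed under extend-immutable-calls.
@app.command("build")
def build(verbose: bool = typer.Option(False, "--verbose", "-v", help="Verbose output.")) -> None:
    typer.echo(f"verbose={verbose}")

# Workaround once there is bandwidth: hoist the call into a module-level
# constant, so the default in the signature is a plain name, not a call.
VERBOSE_OPTION = typer.Option(False, "--verbose", "-v", help="Verbose output.")

@app.command("build2")
def build2(verbose: bool = VERBOSE_OPTION) -> None:
    typer.echo(f"verbose={verbose}")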