From 6c86b0a7bbc2084e096b5bcffc06045ae4a4eafe Mon Sep 17 00:00:00 2001 From: Nathan Evans Date: Mon, 6 Oct 2025 12:06:41 -0700 Subject: [PATCH] Init config cleanup (#2084) * Spruce up init_config output, including LiteLLM default * Remove deployment_name requirement for Azure * Semver * Add model_provider * Add default model_provider * Remove OBE test * Update minimal config for tests * Add model_provider to verb tests --- .../minor-20251003221030515836.json | 4 ++ graphrag/config/defaults.py | 13 ++--- graphrag/config/errors.py | 9 ---- graphrag/config/init_content.py | 51 ++++++++----------- .../config/models/language_model_config.py | 7 ++- .../fixtures/minimal_config/settings.yaml | 6 ++- .../settings.yaml | 6 ++- tests/unit/config/test_config.py | 13 ----- tests/unit/config/utils.py | 2 + tests/verbs/util.py | 2 + 10 files changed, 49 insertions(+), 64 deletions(-) create mode 100644 .semversioner/next-release/minor-20251003221030515836.json diff --git a/.semversioner/next-release/minor-20251003221030515836.json b/.semversioner/next-release/minor-20251003221030515836.json new file mode 100644 index 00000000..a3c39ccc --- /dev/null +++ b/.semversioner/next-release/minor-20251003221030515836.json @@ -0,0 +1,4 @@ +{ + "type": "minor", + "description": "Set LiteLLM as default in init_content." +} diff --git a/graphrag/config/defaults.py b/graphrag/config/defaults.py index 7acab69b..6ec2e2d2 100644 --- a/graphrag/config/defaults.py +++ b/graphrag/config/defaults.py @@ -6,7 +6,7 @@ from collections.abc import Callable from dataclasses import dataclass, field from pathlib import Path -from typing import ClassVar, Literal +from typing import ClassVar from graphrag.config.embeddings import default_embeddings from graphrag.config.enums import ( @@ -46,13 +46,14 @@ from graphrag.language_model.providers.litellm.services.retry.retry import Retry DEFAULT_OUTPUT_BASE_DIR = "output" DEFAULT_CHAT_MODEL_ID = "default_chat_model" -DEFAULT_CHAT_MODEL_TYPE = ModelType.OpenAIChat +DEFAULT_CHAT_MODEL_TYPE = ModelType.Chat DEFAULT_CHAT_MODEL = "gpt-4-turbo-preview" DEFAULT_CHAT_MODEL_AUTH_TYPE = AuthType.APIKey DEFAULT_EMBEDDING_MODEL_ID = "default_embedding_model" -DEFAULT_EMBEDDING_MODEL_TYPE = ModelType.OpenAIEmbedding +DEFAULT_EMBEDDING_MODEL_TYPE = ModelType.Embedding DEFAULT_EMBEDDING_MODEL = "text-embedding-3-small" DEFAULT_EMBEDDING_MODEL_AUTH_TYPE = AuthType.APIKey +DEFAULT_MODEL_PROVIDER = "openai" DEFAULT_VECTOR_STORE_ID = "default_vector_store" ENCODING_MODEL = "cl100k_base" @@ -325,10 +326,10 @@ class LanguageModelDefaults: proxy: None = None audience: None = None model_supports_json: None = None - tokens_per_minute: Literal["auto"] = "auto" - requests_per_minute: Literal["auto"] = "auto" + tokens_per_minute: None = None + requests_per_minute: None = None rate_limit_strategy: str | None = "static" - retry_strategy: str = "native" + retry_strategy: str = "exponential_backoff" max_retries: int = 10 max_retry_wait: float = 10.0 concurrent_requests: int = 25 diff --git a/graphrag/config/errors.py b/graphrag/config/errors.py index 704a3373..32a20b83 100644 --- a/graphrag/config/errors.py +++ b/graphrag/config/errors.py @@ -33,15 +33,6 @@ class AzureApiVersionMissingError(ValueError): super().__init__(msg) -class AzureDeploymentNameMissingError(ValueError): - """Azure Deployment Name missing error.""" - - def __init__(self, llm_type: str) -> None: - """Init method definition.""" - msg = f"Deployment name is required for {llm_type}. Please rerun `graphrag init` set the deployment_name." - super().__init__(msg) - - class LanguageModelConfigMissingError(ValueError): """Missing model configuration error.""" diff --git a/graphrag/config/init_content.py b/graphrag/config/init_content.py index 1eb60cf1..fd16f20d 100644 --- a/graphrag/config/init_content.py +++ b/graphrag/config/init_content.py @@ -19,41 +19,34 @@ INIT_YAML = f"""\ models: {defs.DEFAULT_CHAT_MODEL_ID}: - type: {defs.DEFAULT_CHAT_MODEL_TYPE.value} # or azure_openai_chat - # api_base: https://.openai.azure.com - # api_version: 2024-05-01-preview + type: {defs.DEFAULT_CHAT_MODEL_TYPE.value} + model_provider: {defs.DEFAULT_MODEL_PROVIDER} auth_type: {defs.DEFAULT_CHAT_MODEL_AUTH_TYPE.value} # or azure_managed_identity - api_key: ${{GRAPHRAG_API_KEY}} # set this in the generated .env file - # audience: "https://cognitiveservices.azure.com/.default" - # organization: + api_key: ${{GRAPHRAG_API_KEY}} # set this in the generated .env file, or remove if managed identity model: {defs.DEFAULT_CHAT_MODEL} - # deployment_name: - # encoding_model: {defs.ENCODING_MODEL} # automatically set by tiktoken if left undefined - model_supports_json: true # recommended if this is available for your model. - concurrent_requests: {language_model_defaults.concurrent_requests} # max number of simultaneous LLM requests allowed - async_mode: {language_model_defaults.async_mode.value} # or asyncio - retry_strategy: native - max_retries: {language_model_defaults.max_retries} - tokens_per_minute: {language_model_defaults.tokens_per_minute} # set to null to disable rate limiting - requests_per_minute: {language_model_defaults.requests_per_minute} # set to null to disable rate limiting - {defs.DEFAULT_EMBEDDING_MODEL_ID}: - type: {defs.DEFAULT_EMBEDDING_MODEL_TYPE.value} # or azure_openai_embedding # api_base: https://.openai.azure.com # api_version: 2024-05-01-preview - auth_type: {defs.DEFAULT_EMBEDDING_MODEL_AUTH_TYPE.value} # or azure_managed_identity - api_key: ${{GRAPHRAG_API_KEY}} - # audience: "https://cognitiveservices.azure.com/.default" - # organization: - model: {defs.DEFAULT_EMBEDDING_MODEL} - # deployment_name: - # encoding_model: {defs.ENCODING_MODEL} # automatically set by tiktoken if left undefined model_supports_json: true # recommended if this is available for your model. - concurrent_requests: {language_model_defaults.concurrent_requests} # max number of simultaneous LLM requests allowed + concurrent_requests: {language_model_defaults.concurrent_requests} async_mode: {language_model_defaults.async_mode.value} # or asyncio - retry_strategy: native + retry_strategy: {language_model_defaults.retry_strategy} max_retries: {language_model_defaults.max_retries} - tokens_per_minute: null # set to null to disable rate limiting or auto for dynamic - requests_per_minute: null # set to null to disable rate limiting or auto for dynamic + tokens_per_minute: null + requests_per_minute: null + {defs.DEFAULT_EMBEDDING_MODEL_ID}: + type: {defs.DEFAULT_EMBEDDING_MODEL_TYPE.value} + model_provider: {defs.DEFAULT_MODEL_PROVIDER} + auth_type: {defs.DEFAULT_EMBEDDING_MODEL_AUTH_TYPE.value} + api_key: ${{GRAPHRAG_API_KEY}} + model: {defs.DEFAULT_EMBEDDING_MODEL} + # api_base: https://.openai.azure.com + # api_version: 2024-05-01-preview + concurrent_requests: {language_model_defaults.concurrent_requests} + async_mode: {language_model_defaults.async_mode.value} # or asyncio + retry_strategy: {language_model_defaults.retry_strategy} + max_retries: {language_model_defaults.max_retries} + tokens_per_minute: null + requests_per_minute: null ### Input settings ### @@ -62,7 +55,6 @@ input: type: {graphrag_config_defaults.input.storage.type.value} # or blob base_dir: "{graphrag_config_defaults.input.storage.base_dir}" file_type: {graphrag_config_defaults.input.file_type.value} # [csv, text, json] - chunks: size: {graphrag_config_defaults.chunks.size} @@ -90,7 +82,6 @@ vector_store: type: {vector_store_defaults.type} db_uri: {vector_store_defaults.db_uri} container_name: {vector_store_defaults.container_name} - overwrite: {vector_store_defaults.overwrite} ### Workflow settings ### diff --git a/graphrag/config/models/language_model_config.py b/graphrag/config/models/language_model_config.py index 12d08bb5..05f10670 100644 --- a/graphrag/config/models/language_model_config.py +++ b/graphrag/config/models/language_model_config.py @@ -3,6 +3,7 @@ """Language model configuration.""" +import logging from typing import Literal import tiktoken @@ -14,11 +15,12 @@ from graphrag.config.errors import ( ApiKeyMissingError, AzureApiBaseMissingError, AzureApiVersionMissingError, - AzureDeploymentNameMissingError, ConflictingSettingsError, ) from graphrag.language_model.factory import ModelFactory +logger = logging.getLogger(__name__) + class LanguageModelConfig(BaseModel): """Language model configuration.""" @@ -214,7 +216,8 @@ class LanguageModelConfig(BaseModel): or self.type == ModelType.AzureOpenAIEmbedding or self.model_provider == "azure" # indicates Litellm + AOI ) and (self.deployment_name is None or self.deployment_name.strip() == ""): - raise AzureDeploymentNameMissingError(self.type) + msg = f"deployment_name is not set for Azure-hosted model. This will default to your model name ({self.model}). If different, this should be set." + logger.debug(msg) organization: str | None = Field( description="The organization to use for the LLM service.", diff --git a/tests/unit/config/fixtures/minimal_config/settings.yaml b/tests/unit/config/fixtures/minimal_config/settings.yaml index 2fec50a3..82183a45 100644 --- a/tests/unit/config/fixtures/minimal_config/settings.yaml +++ b/tests/unit/config/fixtures/minimal_config/settings.yaml @@ -1,9 +1,11 @@ models: default_chat_model: api_key: ${CUSTOM_API_KEY} - type: openai_chat + type: chat + model_provider: openai model: gpt-4-turbo-preview default_embedding_model: api_key: ${CUSTOM_API_KEY} - type: openai_embedding + type: embedding + model_provider: openai model: text-embedding-3-small \ No newline at end of file diff --git a/tests/unit/config/fixtures/minimal_config_missing_env_var/settings.yaml b/tests/unit/config/fixtures/minimal_config_missing_env_var/settings.yaml index 9e1c4571..651b997a 100644 --- a/tests/unit/config/fixtures/minimal_config_missing_env_var/settings.yaml +++ b/tests/unit/config/fixtures/minimal_config_missing_env_var/settings.yaml @@ -1,9 +1,11 @@ models: default_chat_model: api_key: ${SOME_NON_EXISTENT_ENV_VAR} - type: openai_chat + type: chat + model_provider: openai model: gpt-4-turbo-preview default_embedding_model: api_key: ${SOME_NON_EXISTENT_ENV_VAR} - type: openai_embedding + type: embedding + model_provider: openai model: text-embedding-3-small \ No newline at end of file diff --git a/tests/unit/config/test_config.py b/tests/unit/config/test_config.py index 0b20fb70..a741de7c 100644 --- a/tests/unit/config/test_config.py +++ b/tests/unit/config/test_config.py @@ -133,19 +133,6 @@ def test_missing_azure_api_version() -> None: }) -def test_missing_azure_deployment_name() -> None: - missing_deployment_name_config = base_azure_model_config.copy() - del missing_deployment_name_config["deployment_name"] - - with pytest.raises(ValidationError): - create_graphrag_config({ - "models": { - defs.DEFAULT_CHAT_MODEL_ID: missing_deployment_name_config, - defs.DEFAULT_EMBEDDING_MODEL_ID: DEFAULT_EMBEDDING_MODEL_CONFIG, - } - }) - - def test_default_config() -> None: expected = get_default_graphrag_config() actual = create_graphrag_config({"models": DEFAULT_MODEL_CONFIG}) diff --git a/tests/unit/config/utils.py b/tests/unit/config/utils.py index 2fa5b141..ab988a50 100644 --- a/tests/unit/config/utils.py +++ b/tests/unit/config/utils.py @@ -41,12 +41,14 @@ DEFAULT_CHAT_MODEL_CONFIG = { "api_key": FAKE_API_KEY, "type": defs.DEFAULT_CHAT_MODEL_TYPE.value, "model": defs.DEFAULT_CHAT_MODEL, + "model_provider": defs.DEFAULT_MODEL_PROVIDER, } DEFAULT_EMBEDDING_MODEL_CONFIG = { "api_key": FAKE_API_KEY, "type": defs.DEFAULT_EMBEDDING_MODEL_TYPE.value, "model": defs.DEFAULT_EMBEDDING_MODEL, + "model_provider": defs.DEFAULT_MODEL_PROVIDER, } DEFAULT_MODEL_CONFIG = { diff --git a/tests/verbs/util.py b/tests/verbs/util.py index 8d342b47..802ab910 100644 --- a/tests/verbs/util.py +++ b/tests/verbs/util.py @@ -17,12 +17,14 @@ DEFAULT_CHAT_MODEL_CONFIG = { "api_key": FAKE_API_KEY, "type": defs.DEFAULT_CHAT_MODEL_TYPE.value, "model": defs.DEFAULT_CHAT_MODEL, + "model_provider": defs.DEFAULT_MODEL_PROVIDER, } DEFAULT_EMBEDDING_MODEL_CONFIG = { "api_key": FAKE_API_KEY, "type": defs.DEFAULT_EMBEDDING_MODEL_TYPE.value, "model": defs.DEFAULT_EMBEDDING_MODEL, + "model_provider": defs.DEFAULT_MODEL_PROVIDER, } DEFAULT_MODEL_CONFIG = {