Init config cleanup (#2084)
Some checks failed
gh-pages / build (push) Has been cancelled
Python CI / python-ci (ubuntu-latest, 3.10) (push) Has been cancelled
Python CI / python-ci (ubuntu-latest, 3.11) (push) Has been cancelled
Python CI / python-ci (windows-latest, 3.10) (push) Has been cancelled
Python CI / python-ci (windows-latest, 3.11) (push) Has been cancelled
Python Integration Tests / python-ci (ubuntu-latest, 3.10) (push) Has been cancelled
Python Integration Tests / python-ci (windows-latest, 3.10) (push) Has been cancelled
Python Notebook Tests / python-ci (ubuntu-latest, 3.10) (push) Has been cancelled
Python Notebook Tests / python-ci (windows-latest, 3.10) (push) Has been cancelled
Python Publish (pypi) / Upload release to PyPI (push) Has been cancelled
Python Smoke Tests / python-ci (ubuntu-latest, 3.10) (push) Has been cancelled
Python Smoke Tests / python-ci (windows-latest, 3.10) (push) Has been cancelled
Spellcheck / spellcheck (push) Has been cancelled

* Spruce up init_config output, including LiteLLM default

* Remove deployment_name requirement for Azure

* Semver

* Add model_provider

* Add default model_provider

* Remove OBE test

* Update minimal config for tests

* Add model_provider to verb tests
This commit is contained in:
Nathan Evans 2025-10-06 12:06:41 -07:00 committed by GitHub
parent 2bd3922d8d
commit 6c86b0a7bb
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
10 changed files with 49 additions and 64 deletions

View File

@@ -0,0 +1,4 @@
{
"type": "minor",
"description": "Set LiteLLM as default in init_content."
}

View File

@@ -6,7 +6,7 @@
from collections.abc import Callable from collections.abc import Callable
from dataclasses import dataclass, field from dataclasses import dataclass, field
from pathlib import Path from pathlib import Path
from typing import ClassVar, Literal from typing import ClassVar
from graphrag.config.embeddings import default_embeddings from graphrag.config.embeddings import default_embeddings
from graphrag.config.enums import ( from graphrag.config.enums import (
@@ -46,13 +46,14 @@ from graphrag.language_model.providers.litellm.services.retry.retry import Retry
DEFAULT_OUTPUT_BASE_DIR = "output" DEFAULT_OUTPUT_BASE_DIR = "output"
DEFAULT_CHAT_MODEL_ID = "default_chat_model" DEFAULT_CHAT_MODEL_ID = "default_chat_model"
DEFAULT_CHAT_MODEL_TYPE = ModelType.OpenAIChat DEFAULT_CHAT_MODEL_TYPE = ModelType.Chat
DEFAULT_CHAT_MODEL = "gpt-4-turbo-preview" DEFAULT_CHAT_MODEL = "gpt-4-turbo-preview"
DEFAULT_CHAT_MODEL_AUTH_TYPE = AuthType.APIKey DEFAULT_CHAT_MODEL_AUTH_TYPE = AuthType.APIKey
DEFAULT_EMBEDDING_MODEL_ID = "default_embedding_model" DEFAULT_EMBEDDING_MODEL_ID = "default_embedding_model"
DEFAULT_EMBEDDING_MODEL_TYPE = ModelType.OpenAIEmbedding DEFAULT_EMBEDDING_MODEL_TYPE = ModelType.Embedding
DEFAULT_EMBEDDING_MODEL = "text-embedding-3-small" DEFAULT_EMBEDDING_MODEL = "text-embedding-3-small"
DEFAULT_EMBEDDING_MODEL_AUTH_TYPE = AuthType.APIKey DEFAULT_EMBEDDING_MODEL_AUTH_TYPE = AuthType.APIKey
DEFAULT_MODEL_PROVIDER = "openai"
DEFAULT_VECTOR_STORE_ID = "default_vector_store" DEFAULT_VECTOR_STORE_ID = "default_vector_store"
ENCODING_MODEL = "cl100k_base" ENCODING_MODEL = "cl100k_base"
@@ -325,10 +326,10 @@ class LanguageModelDefaults:
proxy: None = None proxy: None = None
audience: None = None audience: None = None
model_supports_json: None = None model_supports_json: None = None
tokens_per_minute: Literal["auto"] = "auto" tokens_per_minute: None = None
requests_per_minute: Literal["auto"] = "auto" requests_per_minute: None = None
rate_limit_strategy: str | None = "static" rate_limit_strategy: str | None = "static"
retry_strategy: str = "native" retry_strategy: str = "exponential_backoff"
max_retries: int = 10 max_retries: int = 10
max_retry_wait: float = 10.0 max_retry_wait: float = 10.0
concurrent_requests: int = 25 concurrent_requests: int = 25

View File

@@ -33,15 +33,6 @@ class AzureApiVersionMissingError(ValueError):
super().__init__(msg) super().__init__(msg)
class AzureDeploymentNameMissingError(ValueError):
"""Azure Deployment Name missing error."""
def __init__(self, llm_type: str) -> None:
"""Init method definition."""
msg = f"Deployment name is required for {llm_type}. Please rerun `graphrag init` set the deployment_name."
super().__init__(msg)
class LanguageModelConfigMissingError(ValueError): class LanguageModelConfigMissingError(ValueError):
"""Missing model configuration error.""" """Missing model configuration error."""

View File

@@ -19,41 +19,34 @@ INIT_YAML = f"""\
models: models:
{defs.DEFAULT_CHAT_MODEL_ID}: {defs.DEFAULT_CHAT_MODEL_ID}:
type: {defs.DEFAULT_CHAT_MODEL_TYPE.value} # or azure_openai_chat type: {defs.DEFAULT_CHAT_MODEL_TYPE.value}
# api_base: https://<instance>.openai.azure.com model_provider: {defs.DEFAULT_MODEL_PROVIDER}
# api_version: 2024-05-01-preview
auth_type: {defs.DEFAULT_CHAT_MODEL_AUTH_TYPE.value} # or azure_managed_identity auth_type: {defs.DEFAULT_CHAT_MODEL_AUTH_TYPE.value} # or azure_managed_identity
api_key: ${{GRAPHRAG_API_KEY}} # set this in the generated .env file api_key: ${{GRAPHRAG_API_KEY}} # set this in the generated .env file, or remove if managed identity
# audience: "https://cognitiveservices.azure.com/.default"
# organization: <organization_id>
model: {defs.DEFAULT_CHAT_MODEL} model: {defs.DEFAULT_CHAT_MODEL}
# deployment_name: <azure_model_deployment_name>
# encoding_model: {defs.ENCODING_MODEL} # automatically set by tiktoken if left undefined
model_supports_json: true # recommended if this is available for your model.
concurrent_requests: {language_model_defaults.concurrent_requests} # max number of simultaneous LLM requests allowed
async_mode: {language_model_defaults.async_mode.value} # or asyncio
retry_strategy: native
max_retries: {language_model_defaults.max_retries}
tokens_per_minute: {language_model_defaults.tokens_per_minute} # set to null to disable rate limiting
requests_per_minute: {language_model_defaults.requests_per_minute} # set to null to disable rate limiting
{defs.DEFAULT_EMBEDDING_MODEL_ID}:
type: {defs.DEFAULT_EMBEDDING_MODEL_TYPE.value} # or azure_openai_embedding
# api_base: https://<instance>.openai.azure.com # api_base: https://<instance>.openai.azure.com
# api_version: 2024-05-01-preview # api_version: 2024-05-01-preview
auth_type: {defs.DEFAULT_EMBEDDING_MODEL_AUTH_TYPE.value} # or azure_managed_identity
api_key: ${{GRAPHRAG_API_KEY}}
# audience: "https://cognitiveservices.azure.com/.default"
# organization: <organization_id>
model: {defs.DEFAULT_EMBEDDING_MODEL}
# deployment_name: <azure_model_deployment_name>
# encoding_model: {defs.ENCODING_MODEL} # automatically set by tiktoken if left undefined
model_supports_json: true # recommended if this is available for your model. model_supports_json: true # recommended if this is available for your model.
concurrent_requests: {language_model_defaults.concurrent_requests} # max number of simultaneous LLM requests allowed concurrent_requests: {language_model_defaults.concurrent_requests}
async_mode: {language_model_defaults.async_mode.value} # or asyncio async_mode: {language_model_defaults.async_mode.value} # or asyncio
retry_strategy: native retry_strategy: {language_model_defaults.retry_strategy}
max_retries: {language_model_defaults.max_retries} max_retries: {language_model_defaults.max_retries}
tokens_per_minute: null # set to null to disable rate limiting or auto for dynamic tokens_per_minute: null
requests_per_minute: null # set to null to disable rate limiting or auto for dynamic requests_per_minute: null
{defs.DEFAULT_EMBEDDING_MODEL_ID}:
type: {defs.DEFAULT_EMBEDDING_MODEL_TYPE.value}
model_provider: {defs.DEFAULT_MODEL_PROVIDER}
auth_type: {defs.DEFAULT_EMBEDDING_MODEL_AUTH_TYPE.value}
api_key: ${{GRAPHRAG_API_KEY}}
model: {defs.DEFAULT_EMBEDDING_MODEL}
# api_base: https://<instance>.openai.azure.com
# api_version: 2024-05-01-preview
concurrent_requests: {language_model_defaults.concurrent_requests}
async_mode: {language_model_defaults.async_mode.value} # or asyncio
retry_strategy: {language_model_defaults.retry_strategy}
max_retries: {language_model_defaults.max_retries}
tokens_per_minute: null
requests_per_minute: null
### Input settings ### ### Input settings ###
@@ -62,7 +55,6 @@ input:
type: {graphrag_config_defaults.input.storage.type.value} # or blob type: {graphrag_config_defaults.input.storage.type.value} # or blob
base_dir: "{graphrag_config_defaults.input.storage.base_dir}" base_dir: "{graphrag_config_defaults.input.storage.base_dir}"
file_type: {graphrag_config_defaults.input.file_type.value} # [csv, text, json] file_type: {graphrag_config_defaults.input.file_type.value} # [csv, text, json]
chunks: chunks:
size: {graphrag_config_defaults.chunks.size} size: {graphrag_config_defaults.chunks.size}
@@ -90,7 +82,6 @@ vector_store:
type: {vector_store_defaults.type} type: {vector_store_defaults.type}
db_uri: {vector_store_defaults.db_uri} db_uri: {vector_store_defaults.db_uri}
container_name: {vector_store_defaults.container_name} container_name: {vector_store_defaults.container_name}
overwrite: {vector_store_defaults.overwrite}
### Workflow settings ### ### Workflow settings ###

View File

@@ -3,6 +3,7 @@
"""Language model configuration.""" """Language model configuration."""
import logging
from typing import Literal from typing import Literal
import tiktoken import tiktoken
@@ -14,11 +15,12 @@ from graphrag.config.errors import (
ApiKeyMissingError, ApiKeyMissingError,
AzureApiBaseMissingError, AzureApiBaseMissingError,
AzureApiVersionMissingError, AzureApiVersionMissingError,
AzureDeploymentNameMissingError,
ConflictingSettingsError, ConflictingSettingsError,
) )
from graphrag.language_model.factory import ModelFactory from graphrag.language_model.factory import ModelFactory
logger = logging.getLogger(__name__)
class LanguageModelConfig(BaseModel): class LanguageModelConfig(BaseModel):
"""Language model configuration.""" """Language model configuration."""
@@ -214,7 +216,8 @@ class LanguageModelConfig(BaseModel):
or self.type == ModelType.AzureOpenAIEmbedding or self.type == ModelType.AzureOpenAIEmbedding
or self.model_provider == "azure" # indicates Litellm + AOI or self.model_provider == "azure" # indicates Litellm + AOI
) and (self.deployment_name is None or self.deployment_name.strip() == ""): ) and (self.deployment_name is None or self.deployment_name.strip() == ""):
raise AzureDeploymentNameMissingError(self.type) msg = f"deployment_name is not set for Azure-hosted model. This will default to your model name ({self.model}). If different, this should be set."
logger.debug(msg)
organization: str | None = Field( organization: str | None = Field(
description="The organization to use for the LLM service.", description="The organization to use for the LLM service.",

View File

@@ -1,9 +1,11 @@
models: models:
default_chat_model: default_chat_model:
api_key: ${CUSTOM_API_KEY} api_key: ${CUSTOM_API_KEY}
type: openai_chat type: chat
model_provider: openai
model: gpt-4-turbo-preview model: gpt-4-turbo-preview
default_embedding_model: default_embedding_model:
api_key: ${CUSTOM_API_KEY} api_key: ${CUSTOM_API_KEY}
type: openai_embedding type: embedding
model_provider: openai
model: text-embedding-3-small model: text-embedding-3-small

View File

@@ -1,9 +1,11 @@
models: models:
default_chat_model: default_chat_model:
api_key: ${SOME_NON_EXISTENT_ENV_VAR} api_key: ${SOME_NON_EXISTENT_ENV_VAR}
type: openai_chat type: chat
model_provider: openai
model: gpt-4-turbo-preview model: gpt-4-turbo-preview
default_embedding_model: default_embedding_model:
api_key: ${SOME_NON_EXISTENT_ENV_VAR} api_key: ${SOME_NON_EXISTENT_ENV_VAR}
type: openai_embedding type: embedding
model_provider: openai
model: text-embedding-3-small model: text-embedding-3-small

View File

@@ -133,19 +133,6 @@ def test_missing_azure_api_version() -> None:
}) })
def test_missing_azure_deployment_name() -> None:
missing_deployment_name_config = base_azure_model_config.copy()
del missing_deployment_name_config["deployment_name"]
with pytest.raises(ValidationError):
create_graphrag_config({
"models": {
defs.DEFAULT_CHAT_MODEL_ID: missing_deployment_name_config,
defs.DEFAULT_EMBEDDING_MODEL_ID: DEFAULT_EMBEDDING_MODEL_CONFIG,
}
})
def test_default_config() -> None: def test_default_config() -> None:
expected = get_default_graphrag_config() expected = get_default_graphrag_config()
actual = create_graphrag_config({"models": DEFAULT_MODEL_CONFIG}) actual = create_graphrag_config({"models": DEFAULT_MODEL_CONFIG})

View File

@@ -41,12 +41,14 @@ DEFAULT_CHAT_MODEL_CONFIG = {
"api_key": FAKE_API_KEY, "api_key": FAKE_API_KEY,
"type": defs.DEFAULT_CHAT_MODEL_TYPE.value, "type": defs.DEFAULT_CHAT_MODEL_TYPE.value,
"model": defs.DEFAULT_CHAT_MODEL, "model": defs.DEFAULT_CHAT_MODEL,
"model_provider": defs.DEFAULT_MODEL_PROVIDER,
} }
DEFAULT_EMBEDDING_MODEL_CONFIG = { DEFAULT_EMBEDDING_MODEL_CONFIG = {
"api_key": FAKE_API_KEY, "api_key": FAKE_API_KEY,
"type": defs.DEFAULT_EMBEDDING_MODEL_TYPE.value, "type": defs.DEFAULT_EMBEDDING_MODEL_TYPE.value,
"model": defs.DEFAULT_EMBEDDING_MODEL, "model": defs.DEFAULT_EMBEDDING_MODEL,
"model_provider": defs.DEFAULT_MODEL_PROVIDER,
} }
DEFAULT_MODEL_CONFIG = { DEFAULT_MODEL_CONFIG = {

View File

@@ -17,12 +17,14 @@ DEFAULT_CHAT_MODEL_CONFIG = {
"api_key": FAKE_API_KEY, "api_key": FAKE_API_KEY,
"type": defs.DEFAULT_CHAT_MODEL_TYPE.value, "type": defs.DEFAULT_CHAT_MODEL_TYPE.value,
"model": defs.DEFAULT_CHAT_MODEL, "model": defs.DEFAULT_CHAT_MODEL,
"model_provider": defs.DEFAULT_MODEL_PROVIDER,
} }
DEFAULT_EMBEDDING_MODEL_CONFIG = { DEFAULT_EMBEDDING_MODEL_CONFIG = {
"api_key": FAKE_API_KEY, "api_key": FAKE_API_KEY,
"type": defs.DEFAULT_EMBEDDING_MODEL_TYPE.value, "type": defs.DEFAULT_EMBEDDING_MODEL_TYPE.value,
"model": defs.DEFAULT_EMBEDDING_MODEL, "model": defs.DEFAULT_EMBEDDING_MODEL,
"model_provider": defs.DEFAULT_MODEL_PROVIDER,
} }
DEFAULT_MODEL_CONFIG = { DEFAULT_MODEL_CONFIG = {