Init config cleanup (#2084)

* Spruce up init_config output, including LiteLLM default

* Remove deployment_name requirement for Azure

* Semver

* Add model_provider

* Add default model_provider

* Remove OBE test

* Update minimal config for tests

* Add model_provider to verb tests
Nathan Evans 2025-10-06 12:06:41 -07:00 committed by GitHub
parent 2bd3922d8d
commit 6c86b0a7bb
10 changed files with 49 additions and 64 deletions

View File

@@ -0,0 +1,4 @@
+{
+    "type": "minor",
+    "description": "Set LiteLLM as default in init_content."
+}

View File

@@ -6,7 +6,7 @@
 from collections.abc import Callable
 from dataclasses import dataclass, field
 from pathlib import Path
-from typing import ClassVar, Literal
+from typing import ClassVar
 
 from graphrag.config.embeddings import default_embeddings
 from graphrag.config.enums import (
@@ -46,13 +46,14 @@ from graphrag.language_model.providers.litellm.services.retry.retry import Retry
 DEFAULT_OUTPUT_BASE_DIR = "output"
 DEFAULT_CHAT_MODEL_ID = "default_chat_model"
-DEFAULT_CHAT_MODEL_TYPE = ModelType.OpenAIChat
+DEFAULT_CHAT_MODEL_TYPE = ModelType.Chat
 DEFAULT_CHAT_MODEL = "gpt-4-turbo-preview"
 DEFAULT_CHAT_MODEL_AUTH_TYPE = AuthType.APIKey
 DEFAULT_EMBEDDING_MODEL_ID = "default_embedding_model"
-DEFAULT_EMBEDDING_MODEL_TYPE = ModelType.OpenAIEmbedding
+DEFAULT_EMBEDDING_MODEL_TYPE = ModelType.Embedding
 DEFAULT_EMBEDDING_MODEL = "text-embedding-3-small"
 DEFAULT_EMBEDDING_MODEL_AUTH_TYPE = AuthType.APIKey
+DEFAULT_MODEL_PROVIDER = "openai"
 DEFAULT_VECTOR_STORE_ID = "default_vector_store"
 ENCODING_MODEL = "cl100k_base"
@@ -325,10 +326,10 @@ class LanguageModelDefaults:
     proxy: None = None
     audience: None = None
     model_supports_json: None = None
-    tokens_per_minute: Literal["auto"] = "auto"
-    requests_per_minute: Literal["auto"] = "auto"
+    tokens_per_minute: None = None
+    requests_per_minute: None = None
     rate_limit_strategy: str | None = "static"
-    retry_strategy: str = "native"
+    retry_strategy: str = "exponential_backoff"
     max_retries: int = 10
     max_retry_wait: float = 10.0
     concurrent_requests: int = 25

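Taken together, this hunk flips the out-of-the-box behavior: rate limiting is now off unless explicitly configured (previously "auto"), and retries default to exponential backoff instead of the provider-native strategy. A minimal sketch of what downstream code now sees, assuming `LanguageModelDefaults` is importable from `graphrag.config.defaults` (the module path is not shown in the diff):

```python
# Sketch only: import path assumed, not confirmed by the diff.
from graphrag.config.defaults import LanguageModelDefaults

defaults = LanguageModelDefaults()
assert defaults.tokens_per_minute is None    # was "auto" (rate limiting on by default)
assert defaults.requests_per_minute is None  # was "auto"
assert defaults.retry_strategy == "exponential_backoff"  # was "native"
```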
View File

@@ -33,15 +33,6 @@ class AzureApiVersionMissingError(ValueError):
         super().__init__(msg)
 
-
-class AzureDeploymentNameMissingError(ValueError):
-    """Azure Deployment Name missing error."""
-
-    def __init__(self, llm_type: str) -> None:
-        """Init method definition."""
-        msg = f"Deployment name is required for {llm_type}. Please rerun `graphrag init` set the deployment_name."
-        super().__init__(msg)
-
 
 class LanguageModelConfigMissingError(ValueError):
     """Missing model configuration error."""

View File

@@ -19,41 +19,34 @@ INIT_YAML = f"""\
 models:
   {defs.DEFAULT_CHAT_MODEL_ID}:
-    type: {defs.DEFAULT_CHAT_MODEL_TYPE.value} # or azure_openai_chat
-    # api_base: https://<instance>.openai.azure.com
-    # api_version: 2024-05-01-preview
+    type: {defs.DEFAULT_CHAT_MODEL_TYPE.value}
+    model_provider: {defs.DEFAULT_MODEL_PROVIDER}
     auth_type: {defs.DEFAULT_CHAT_MODEL_AUTH_TYPE.value} # or azure_managed_identity
-    api_key: ${{GRAPHRAG_API_KEY}} # set this in the generated .env file
-    # audience: "https://cognitiveservices.azure.com/.default"
-    # organization: <organization_id>
+    api_key: ${{GRAPHRAG_API_KEY}} # set this in the generated .env file, or remove if managed identity
     model: {defs.DEFAULT_CHAT_MODEL}
-    # deployment_name: <azure_model_deployment_name>
-    # encoding_model: {defs.ENCODING_MODEL} # automatically set by tiktoken if left undefined
-    model_supports_json: true # recommended if this is available for your model.
-    concurrent_requests: {language_model_defaults.concurrent_requests} # max number of simultaneous LLM requests allowed
-    async_mode: {language_model_defaults.async_mode.value} # or asyncio
-    retry_strategy: native
-    max_retries: {language_model_defaults.max_retries}
-    tokens_per_minute: {language_model_defaults.tokens_per_minute} # set to null to disable rate limiting
-    requests_per_minute: {language_model_defaults.requests_per_minute} # set to null to disable rate limiting
-  {defs.DEFAULT_EMBEDDING_MODEL_ID}:
-    type: {defs.DEFAULT_EMBEDDING_MODEL_TYPE.value} # or azure_openai_embedding
-    # api_base: https://<instance>.openai.azure.com
-    # api_version: 2024-05-01-preview
-    auth_type: {defs.DEFAULT_EMBEDDING_MODEL_AUTH_TYPE.value} # or azure_managed_identity
-    api_key: ${{GRAPHRAG_API_KEY}}
-    # audience: "https://cognitiveservices.azure.com/.default"
-    # organization: <organization_id>
-    model: {defs.DEFAULT_EMBEDDING_MODEL}
-    # deployment_name: <azure_model_deployment_name>
-    # encoding_model: {defs.ENCODING_MODEL} # automatically set by tiktoken if left undefined
-    model_supports_json: true # recommended if this is available for your model.
-    concurrent_requests: {language_model_defaults.concurrent_requests} # max number of simultaneous LLM requests allowed
+    concurrent_requests: {language_model_defaults.concurrent_requests}
     async_mode: {language_model_defaults.async_mode.value} # or asyncio
-    retry_strategy: native
+    retry_strategy: {language_model_defaults.retry_strategy}
     max_retries: {language_model_defaults.max_retries}
-    tokens_per_minute: null # set to null to disable rate limiting or auto for dynamic
-    requests_per_minute: null # set to null to disable rate limiting or auto for dynamic
+    tokens_per_minute: null
+    requests_per_minute: null
+  {defs.DEFAULT_EMBEDDING_MODEL_ID}:
+    type: {defs.DEFAULT_EMBEDDING_MODEL_TYPE.value}
+    model_provider: {defs.DEFAULT_MODEL_PROVIDER}
+    auth_type: {defs.DEFAULT_EMBEDDING_MODEL_AUTH_TYPE.value}
+    api_key: ${{GRAPHRAG_API_KEY}}
+    model: {defs.DEFAULT_EMBEDDING_MODEL}
+    # api_base: https://<instance>.openai.azure.com
+    # api_version: 2024-05-01-preview
+    concurrent_requests: {language_model_defaults.concurrent_requests}
+    async_mode: {language_model_defaults.async_mode.value} # or asyncio
+    retry_strategy: {language_model_defaults.retry_strategy}
+    max_retries: {language_model_defaults.max_retries}
+    tokens_per_minute: null
+    requests_per_minute: null
 
 ### Input settings ###
@@ -62,7 +55,6 @@ input:
   type: {graphrag_config_defaults.input.storage.type.value} # or blob
   base_dir: "{graphrag_config_defaults.input.storage.base_dir}"
   file_type: {graphrag_config_defaults.input.file_type.value} # [csv, text, json]
 
 chunks:
   size: {graphrag_config_defaults.chunks.size}
@@ -90,7 +82,6 @@ vector_store:
   type: {vector_store_defaults.type}
   db_uri: {vector_store_defaults.db_uri}
   container_name: {vector_store_defaults.container_name}
   overwrite: {vector_store_defaults.overwrite}
 
 ### Workflow settings ###

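The template this hunk generates can be inspected directly; a minimal sketch, assuming `INIT_YAML` lives in `graphrag.config.init_content` as the f-string context suggests:

```python
# Sketch: print the settings.yaml template that `graphrag init` writes.
# Module path assumed from the diff context.
from graphrag.config.init_content import INIT_YAML

print(INIT_YAML)
```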
View File

@@ -3,6 +3,7 @@
 """Language model configuration."""
 
+import logging
 from typing import Literal
 
 import tiktoken
@@ -14,11 +15,12 @@ from graphrag.config.errors import (
     ApiKeyMissingError,
     AzureApiBaseMissingError,
     AzureApiVersionMissingError,
-    AzureDeploymentNameMissingError,
     ConflictingSettingsError,
 )
 from graphrag.language_model.factory import ModelFactory
 
+logger = logging.getLogger(__name__)
 
 class LanguageModelConfig(BaseModel):
     """Language model configuration."""
@@ -214,7 +216,8 @@
             or self.type == ModelType.AzureOpenAIEmbedding
             or self.model_provider == "azure"  # indicates Litellm + AOI
         ) and (self.deployment_name is None or self.deployment_name.strip() == ""):
-            raise AzureDeploymentNameMissingError(self.type)
+            msg = f"deployment_name is not set for Azure-hosted model. This will default to your model name ({self.model}). If different, this should be set."
+            logger.debug(msg)
 
     organization: str | None = Field(
         description="The organization to use for the LLM service.",

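The behavioral change here: an Azure-hosted model with no `deployment_name` now passes validation and logs a debug hint rather than raising `AzureDeploymentNameMissingError`. A rough sketch of exercising it, assuming `create_graphrag_config` imports as in the tests below; the endpoint and key values are hypothetical:

```python
import logging

from graphrag.config.create_graphrag_config import create_graphrag_config  # path assumed

logging.basicConfig(level=logging.DEBUG)

# Previously this raised AzureDeploymentNameMissingError; now deployment_name
# silently falls back to the model name, with a debug-level hint.
create_graphrag_config({
    "models": {
        "default_chat_model": {
            "type": "azure_openai_chat",
            "auth_type": "api_key",
            "api_key": "dummy-key",                              # hypothetical
            "api_base": "https://my-instance.openai.azure.com",  # hypothetical
            "api_version": "2024-05-01-preview",
            "model": "gpt-4-turbo-preview",
            # deployment_name intentionally omitted
        },
        "default_embedding_model": {
            "type": "embedding",
            "model_provider": "openai",
            "api_key": "dummy-key",  # hypothetical
            "model": "text-embedding-3-small",
        },
    }
})
```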
View File

@@ -1,9 +1,11 @@
 models:
   default_chat_model:
     api_key: ${CUSTOM_API_KEY}
-    type: openai_chat
+    type: chat
+    model_provider: openai
     model: gpt-4-turbo-preview
   default_embedding_model:
     api_key: ${CUSTOM_API_KEY}
-    type: openai_embedding
+    type: embedding
+    model_provider: openai
     model: text-embedding-3-small

View File

@@ -1,9 +1,11 @@
 models:
   default_chat_model:
     api_key: ${SOME_NON_EXISTENT_ENV_VAR}
-    type: openai_chat
+    type: chat
+    model_provider: openai
     model: gpt-4-turbo-preview
   default_embedding_model:
     api_key: ${SOME_NON_EXISTENT_ENV_VAR}
-    type: openai_embedding
+    type: embedding
+    model_provider: openai
     model: text-embedding-3-small

View File

@@ -133,19 +133,6 @@ def test_missing_azure_api_version() -> None:
         })
 
-
-def test_missing_azure_deployment_name() -> None:
-    missing_deployment_name_config = base_azure_model_config.copy()
-    del missing_deployment_name_config["deployment_name"]
-    with pytest.raises(ValidationError):
-        create_graphrag_config({
-            "models": {
-                defs.DEFAULT_CHAT_MODEL_ID: missing_deployment_name_config,
-                defs.DEFAULT_EMBEDDING_MODEL_ID: DEFAULT_EMBEDDING_MODEL_CONFIG,
-            }
-        })
-
 
 def test_default_config() -> None:
     expected = get_default_graphrag_config()
     actual = create_graphrag_config({"models": DEFAULT_MODEL_CONFIG})

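If coverage of the new fallback is wanted, a hypothetical replacement test (not part of this PR) could reuse the deleted test's fixtures:

```python
def test_azure_deployment_name_falls_back_to_model() -> None:
    # Hypothetical test: omitting deployment_name for an Azure model
    # should now validate instead of raising.
    config = base_azure_model_config.copy()
    del config["deployment_name"]
    create_graphrag_config({
        "models": {
            defs.DEFAULT_CHAT_MODEL_ID: config,
            defs.DEFAULT_EMBEDDING_MODEL_ID: DEFAULT_EMBEDDING_MODEL_CONFIG,
        }
    })  # should not raise
```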
View File

@@ -41,12 +41,14 @@ DEFAULT_CHAT_MODEL_CONFIG = {
     "api_key": FAKE_API_KEY,
     "type": defs.DEFAULT_CHAT_MODEL_TYPE.value,
     "model": defs.DEFAULT_CHAT_MODEL,
+    "model_provider": defs.DEFAULT_MODEL_PROVIDER,
 }
 
 DEFAULT_EMBEDDING_MODEL_CONFIG = {
     "api_key": FAKE_API_KEY,
     "type": defs.DEFAULT_EMBEDDING_MODEL_TYPE.value,
     "model": defs.DEFAULT_EMBEDDING_MODEL,
+    "model_provider": defs.DEFAULT_MODEL_PROVIDER,
 }
 
 DEFAULT_MODEL_CONFIG = {

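A rough usage sketch of the updated fixtures, assuming `DEFAULT_MODEL_CONFIG` maps the two default model IDs to these dicts (its body is truncated above) and that the resulting config exposes `models` as a mapping:

```python
# Sketch: the fixtures now carry model_provider for the litellm-style model types.
config = create_graphrag_config({"models": DEFAULT_MODEL_CONFIG})
assert config.models[defs.DEFAULT_CHAT_MODEL_ID].model_provider == defs.DEFAULT_MODEL_PROVIDER
```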
View File

@@ -17,12 +17,14 @@ DEFAULT_CHAT_MODEL_CONFIG = {
     "api_key": FAKE_API_KEY,
     "type": defs.DEFAULT_CHAT_MODEL_TYPE.value,
     "model": defs.DEFAULT_CHAT_MODEL,
+    "model_provider": defs.DEFAULT_MODEL_PROVIDER,
 }
 
 DEFAULT_EMBEDDING_MODEL_CONFIG = {
     "api_key": FAKE_API_KEY,
     "type": defs.DEFAULT_EMBEDDING_MODEL_TYPE.value,
     "model": defs.DEFAULT_EMBEDDING_MODEL,
+    "model_provider": defs.DEFAULT_MODEL_PROVIDER,
 }
 
 DEFAULT_MODEL_CONFIG = {