mirror of
https://github.com/microsoft/graphrag.git
synced 2026-01-14 09:07:20 +08:00
Init config cleanup (#2084)
Some checks failed
gh-pages / build (push) Has been cancelled
Python CI / python-ci (ubuntu-latest, 3.10) (push) Has been cancelled
Python CI / python-ci (ubuntu-latest, 3.11) (push) Has been cancelled
Python CI / python-ci (windows-latest, 3.10) (push) Has been cancelled
Python CI / python-ci (windows-latest, 3.11) (push) Has been cancelled
Python Integration Tests / python-ci (ubuntu-latest, 3.10) (push) Has been cancelled
Python Integration Tests / python-ci (windows-latest, 3.10) (push) Has been cancelled
Python Notebook Tests / python-ci (ubuntu-latest, 3.10) (push) Has been cancelled
Python Notebook Tests / python-ci (windows-latest, 3.10) (push) Has been cancelled
Python Publish (pypi) / Upload release to PyPI (push) Has been cancelled
Python Smoke Tests / python-ci (ubuntu-latest, 3.10) (push) Has been cancelled
Python Smoke Tests / python-ci (windows-latest, 3.10) (push) Has been cancelled
Spellcheck / spellcheck (push) Has been cancelled
Some checks failed
gh-pages / build (push) Has been cancelled
Python CI / python-ci (ubuntu-latest, 3.10) (push) Has been cancelled
Python CI / python-ci (ubuntu-latest, 3.11) (push) Has been cancelled
Python CI / python-ci (windows-latest, 3.10) (push) Has been cancelled
Python CI / python-ci (windows-latest, 3.11) (push) Has been cancelled
Python Integration Tests / python-ci (ubuntu-latest, 3.10) (push) Has been cancelled
Python Integration Tests / python-ci (windows-latest, 3.10) (push) Has been cancelled
Python Notebook Tests / python-ci (ubuntu-latest, 3.10) (push) Has been cancelled
Python Notebook Tests / python-ci (windows-latest, 3.10) (push) Has been cancelled
Python Publish (pypi) / Upload release to PyPI (push) Has been cancelled
Python Smoke Tests / python-ci (ubuntu-latest, 3.10) (push) Has been cancelled
Python Smoke Tests / python-ci (windows-latest, 3.10) (push) Has been cancelled
Spellcheck / spellcheck (push) Has been cancelled
* Spruce up init_config output, including LiteLLM default * Remove deployment_name requirement for Azure * Semver * Add model_provider * Add default model_provider * Remove OBE test * Update minimal config for tests * Add model_provider to verb tests
This commit is contained in:
parent
2bd3922d8d
commit
6c86b0a7bb
@ -0,0 +1,4 @@
|
|||||||
|
{
|
||||||
|
"type": "minor",
|
||||||
|
"description": "Set LiteLLM as default in init_content."
|
||||||
|
}
|
||||||
@ -6,7 +6,7 @@
|
|||||||
from collections.abc import Callable
|
from collections.abc import Callable
|
||||||
from dataclasses import dataclass, field
|
from dataclasses import dataclass, field
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import ClassVar, Literal
|
from typing import ClassVar
|
||||||
|
|
||||||
from graphrag.config.embeddings import default_embeddings
|
from graphrag.config.embeddings import default_embeddings
|
||||||
from graphrag.config.enums import (
|
from graphrag.config.enums import (
|
||||||
@ -46,13 +46,14 @@ from graphrag.language_model.providers.litellm.services.retry.retry import Retry
|
|||||||
|
|
||||||
DEFAULT_OUTPUT_BASE_DIR = "output"
|
DEFAULT_OUTPUT_BASE_DIR = "output"
|
||||||
DEFAULT_CHAT_MODEL_ID = "default_chat_model"
|
DEFAULT_CHAT_MODEL_ID = "default_chat_model"
|
||||||
DEFAULT_CHAT_MODEL_TYPE = ModelType.OpenAIChat
|
DEFAULT_CHAT_MODEL_TYPE = ModelType.Chat
|
||||||
DEFAULT_CHAT_MODEL = "gpt-4-turbo-preview"
|
DEFAULT_CHAT_MODEL = "gpt-4-turbo-preview"
|
||||||
DEFAULT_CHAT_MODEL_AUTH_TYPE = AuthType.APIKey
|
DEFAULT_CHAT_MODEL_AUTH_TYPE = AuthType.APIKey
|
||||||
DEFAULT_EMBEDDING_MODEL_ID = "default_embedding_model"
|
DEFAULT_EMBEDDING_MODEL_ID = "default_embedding_model"
|
||||||
DEFAULT_EMBEDDING_MODEL_TYPE = ModelType.OpenAIEmbedding
|
DEFAULT_EMBEDDING_MODEL_TYPE = ModelType.Embedding
|
||||||
DEFAULT_EMBEDDING_MODEL = "text-embedding-3-small"
|
DEFAULT_EMBEDDING_MODEL = "text-embedding-3-small"
|
||||||
DEFAULT_EMBEDDING_MODEL_AUTH_TYPE = AuthType.APIKey
|
DEFAULT_EMBEDDING_MODEL_AUTH_TYPE = AuthType.APIKey
|
||||||
|
DEFAULT_MODEL_PROVIDER = "openai"
|
||||||
DEFAULT_VECTOR_STORE_ID = "default_vector_store"
|
DEFAULT_VECTOR_STORE_ID = "default_vector_store"
|
||||||
|
|
||||||
ENCODING_MODEL = "cl100k_base"
|
ENCODING_MODEL = "cl100k_base"
|
||||||
@ -325,10 +326,10 @@ class LanguageModelDefaults:
|
|||||||
proxy: None = None
|
proxy: None = None
|
||||||
audience: None = None
|
audience: None = None
|
||||||
model_supports_json: None = None
|
model_supports_json: None = None
|
||||||
tokens_per_minute: Literal["auto"] = "auto"
|
tokens_per_minute: None = None
|
||||||
requests_per_minute: Literal["auto"] = "auto"
|
requests_per_minute: None = None
|
||||||
rate_limit_strategy: str | None = "static"
|
rate_limit_strategy: str | None = "static"
|
||||||
retry_strategy: str = "native"
|
retry_strategy: str = "exponential_backoff"
|
||||||
max_retries: int = 10
|
max_retries: int = 10
|
||||||
max_retry_wait: float = 10.0
|
max_retry_wait: float = 10.0
|
||||||
concurrent_requests: int = 25
|
concurrent_requests: int = 25
|
||||||
|
|||||||
@ -33,15 +33,6 @@ class AzureApiVersionMissingError(ValueError):
|
|||||||
super().__init__(msg)
|
super().__init__(msg)
|
||||||
|
|
||||||
|
|
||||||
class AzureDeploymentNameMissingError(ValueError):
|
|
||||||
"""Azure Deployment Name missing error."""
|
|
||||||
|
|
||||||
def __init__(self, llm_type: str) -> None:
|
|
||||||
"""Init method definition."""
|
|
||||||
msg = f"Deployment name is required for {llm_type}. Please rerun `graphrag init` set the deployment_name."
|
|
||||||
super().__init__(msg)
|
|
||||||
|
|
||||||
|
|
||||||
class LanguageModelConfigMissingError(ValueError):
|
class LanguageModelConfigMissingError(ValueError):
|
||||||
"""Missing model configuration error."""
|
"""Missing model configuration error."""
|
||||||
|
|
||||||
|
|||||||
@ -19,41 +19,34 @@ INIT_YAML = f"""\
|
|||||||
|
|
||||||
models:
|
models:
|
||||||
{defs.DEFAULT_CHAT_MODEL_ID}:
|
{defs.DEFAULT_CHAT_MODEL_ID}:
|
||||||
type: {defs.DEFAULT_CHAT_MODEL_TYPE.value} # or azure_openai_chat
|
type: {defs.DEFAULT_CHAT_MODEL_TYPE.value}
|
||||||
# api_base: https://<instance>.openai.azure.com
|
model_provider: {defs.DEFAULT_MODEL_PROVIDER}
|
||||||
# api_version: 2024-05-01-preview
|
|
||||||
auth_type: {defs.DEFAULT_CHAT_MODEL_AUTH_TYPE.value} # or azure_managed_identity
|
auth_type: {defs.DEFAULT_CHAT_MODEL_AUTH_TYPE.value} # or azure_managed_identity
|
||||||
api_key: ${{GRAPHRAG_API_KEY}} # set this in the generated .env file
|
api_key: ${{GRAPHRAG_API_KEY}} # set this in the generated .env file, or remove if managed identity
|
||||||
# audience: "https://cognitiveservices.azure.com/.default"
|
|
||||||
# organization: <organization_id>
|
|
||||||
model: {defs.DEFAULT_CHAT_MODEL}
|
model: {defs.DEFAULT_CHAT_MODEL}
|
||||||
# deployment_name: <azure_model_deployment_name>
|
|
||||||
# encoding_model: {defs.ENCODING_MODEL} # automatically set by tiktoken if left undefined
|
|
||||||
model_supports_json: true # recommended if this is available for your model.
|
|
||||||
concurrent_requests: {language_model_defaults.concurrent_requests} # max number of simultaneous LLM requests allowed
|
|
||||||
async_mode: {language_model_defaults.async_mode.value} # or asyncio
|
|
||||||
retry_strategy: native
|
|
||||||
max_retries: {language_model_defaults.max_retries}
|
|
||||||
tokens_per_minute: {language_model_defaults.tokens_per_minute} # set to null to disable rate limiting
|
|
||||||
requests_per_minute: {language_model_defaults.requests_per_minute} # set to null to disable rate limiting
|
|
||||||
{defs.DEFAULT_EMBEDDING_MODEL_ID}:
|
|
||||||
type: {defs.DEFAULT_EMBEDDING_MODEL_TYPE.value} # or azure_openai_embedding
|
|
||||||
# api_base: https://<instance>.openai.azure.com
|
# api_base: https://<instance>.openai.azure.com
|
||||||
# api_version: 2024-05-01-preview
|
# api_version: 2024-05-01-preview
|
||||||
auth_type: {defs.DEFAULT_EMBEDDING_MODEL_AUTH_TYPE.value} # or azure_managed_identity
|
|
||||||
api_key: ${{GRAPHRAG_API_KEY}}
|
|
||||||
# audience: "https://cognitiveservices.azure.com/.default"
|
|
||||||
# organization: <organization_id>
|
|
||||||
model: {defs.DEFAULT_EMBEDDING_MODEL}
|
|
||||||
# deployment_name: <azure_model_deployment_name>
|
|
||||||
# encoding_model: {defs.ENCODING_MODEL} # automatically set by tiktoken if left undefined
|
|
||||||
model_supports_json: true # recommended if this is available for your model.
|
model_supports_json: true # recommended if this is available for your model.
|
||||||
concurrent_requests: {language_model_defaults.concurrent_requests} # max number of simultaneous LLM requests allowed
|
concurrent_requests: {language_model_defaults.concurrent_requests}
|
||||||
async_mode: {language_model_defaults.async_mode.value} # or asyncio
|
async_mode: {language_model_defaults.async_mode.value} # or asyncio
|
||||||
retry_strategy: native
|
retry_strategy: {language_model_defaults.retry_strategy}
|
||||||
max_retries: {language_model_defaults.max_retries}
|
max_retries: {language_model_defaults.max_retries}
|
||||||
tokens_per_minute: null # set to null to disable rate limiting or auto for dynamic
|
tokens_per_minute: null
|
||||||
requests_per_minute: null # set to null to disable rate limiting or auto for dynamic
|
requests_per_minute: null
|
||||||
|
{defs.DEFAULT_EMBEDDING_MODEL_ID}:
|
||||||
|
type: {defs.DEFAULT_EMBEDDING_MODEL_TYPE.value}
|
||||||
|
model_provider: {defs.DEFAULT_MODEL_PROVIDER}
|
||||||
|
auth_type: {defs.DEFAULT_EMBEDDING_MODEL_AUTH_TYPE.value}
|
||||||
|
api_key: ${{GRAPHRAG_API_KEY}}
|
||||||
|
model: {defs.DEFAULT_EMBEDDING_MODEL}
|
||||||
|
# api_base: https://<instance>.openai.azure.com
|
||||||
|
# api_version: 2024-05-01-preview
|
||||||
|
concurrent_requests: {language_model_defaults.concurrent_requests}
|
||||||
|
async_mode: {language_model_defaults.async_mode.value} # or asyncio
|
||||||
|
retry_strategy: {language_model_defaults.retry_strategy}
|
||||||
|
max_retries: {language_model_defaults.max_retries}
|
||||||
|
tokens_per_minute: null
|
||||||
|
requests_per_minute: null
|
||||||
|
|
||||||
### Input settings ###
|
### Input settings ###
|
||||||
|
|
||||||
@ -62,7 +55,6 @@ input:
|
|||||||
type: {graphrag_config_defaults.input.storage.type.value} # or blob
|
type: {graphrag_config_defaults.input.storage.type.value} # or blob
|
||||||
base_dir: "{graphrag_config_defaults.input.storage.base_dir}"
|
base_dir: "{graphrag_config_defaults.input.storage.base_dir}"
|
||||||
file_type: {graphrag_config_defaults.input.file_type.value} # [csv, text, json]
|
file_type: {graphrag_config_defaults.input.file_type.value} # [csv, text, json]
|
||||||
|
|
||||||
|
|
||||||
chunks:
|
chunks:
|
||||||
size: {graphrag_config_defaults.chunks.size}
|
size: {graphrag_config_defaults.chunks.size}
|
||||||
@ -90,7 +82,6 @@ vector_store:
|
|||||||
type: {vector_store_defaults.type}
|
type: {vector_store_defaults.type}
|
||||||
db_uri: {vector_store_defaults.db_uri}
|
db_uri: {vector_store_defaults.db_uri}
|
||||||
container_name: {vector_store_defaults.container_name}
|
container_name: {vector_store_defaults.container_name}
|
||||||
overwrite: {vector_store_defaults.overwrite}
|
|
||||||
|
|
||||||
### Workflow settings ###
|
### Workflow settings ###
|
||||||
|
|
||||||
|
|||||||
@ -3,6 +3,7 @@
|
|||||||
|
|
||||||
"""Language model configuration."""
|
"""Language model configuration."""
|
||||||
|
|
||||||
|
import logging
|
||||||
from typing import Literal
|
from typing import Literal
|
||||||
|
|
||||||
import tiktoken
|
import tiktoken
|
||||||
@ -14,11 +15,12 @@ from graphrag.config.errors import (
|
|||||||
ApiKeyMissingError,
|
ApiKeyMissingError,
|
||||||
AzureApiBaseMissingError,
|
AzureApiBaseMissingError,
|
||||||
AzureApiVersionMissingError,
|
AzureApiVersionMissingError,
|
||||||
AzureDeploymentNameMissingError,
|
|
||||||
ConflictingSettingsError,
|
ConflictingSettingsError,
|
||||||
)
|
)
|
||||||
from graphrag.language_model.factory import ModelFactory
|
from graphrag.language_model.factory import ModelFactory
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
class LanguageModelConfig(BaseModel):
|
class LanguageModelConfig(BaseModel):
|
||||||
"""Language model configuration."""
|
"""Language model configuration."""
|
||||||
@ -214,7 +216,8 @@ class LanguageModelConfig(BaseModel):
|
|||||||
or self.type == ModelType.AzureOpenAIEmbedding
|
or self.type == ModelType.AzureOpenAIEmbedding
|
||||||
or self.model_provider == "azure" # indicates Litellm + AOI
|
or self.model_provider == "azure" # indicates Litellm + AOI
|
||||||
) and (self.deployment_name is None or self.deployment_name.strip() == ""):
|
) and (self.deployment_name is None or self.deployment_name.strip() == ""):
|
||||||
raise AzureDeploymentNameMissingError(self.type)
|
msg = f"deployment_name is not set for Azure-hosted model. This will default to your model name ({self.model}). If different, this should be set."
|
||||||
|
logger.debug(msg)
|
||||||
|
|
||||||
organization: str | None = Field(
|
organization: str | None = Field(
|
||||||
description="The organization to use for the LLM service.",
|
description="The organization to use for the LLM service.",
|
||||||
|
|||||||
@ -1,9 +1,11 @@
|
|||||||
models:
|
models:
|
||||||
default_chat_model:
|
default_chat_model:
|
||||||
api_key: ${CUSTOM_API_KEY}
|
api_key: ${CUSTOM_API_KEY}
|
||||||
type: openai_chat
|
type: chat
|
||||||
|
model_provider: openai
|
||||||
model: gpt-4-turbo-preview
|
model: gpt-4-turbo-preview
|
||||||
default_embedding_model:
|
default_embedding_model:
|
||||||
api_key: ${CUSTOM_API_KEY}
|
api_key: ${CUSTOM_API_KEY}
|
||||||
type: openai_embedding
|
type: embedding
|
||||||
|
model_provider: openai
|
||||||
model: text-embedding-3-small
|
model: text-embedding-3-small
|
||||||
@ -1,9 +1,11 @@
|
|||||||
models:
|
models:
|
||||||
default_chat_model:
|
default_chat_model:
|
||||||
api_key: ${SOME_NON_EXISTENT_ENV_VAR}
|
api_key: ${SOME_NON_EXISTENT_ENV_VAR}
|
||||||
type: openai_chat
|
type: chat
|
||||||
|
model_provider: openai
|
||||||
model: gpt-4-turbo-preview
|
model: gpt-4-turbo-preview
|
||||||
default_embedding_model:
|
default_embedding_model:
|
||||||
api_key: ${SOME_NON_EXISTENT_ENV_VAR}
|
api_key: ${SOME_NON_EXISTENT_ENV_VAR}
|
||||||
type: openai_embedding
|
type: embedding
|
||||||
|
model_provider: openai
|
||||||
model: text-embedding-3-small
|
model: text-embedding-3-small
|
||||||
@ -133,19 +133,6 @@ def test_missing_azure_api_version() -> None:
|
|||||||
})
|
})
|
||||||
|
|
||||||
|
|
||||||
def test_missing_azure_deployment_name() -> None:
|
|
||||||
missing_deployment_name_config = base_azure_model_config.copy()
|
|
||||||
del missing_deployment_name_config["deployment_name"]
|
|
||||||
|
|
||||||
with pytest.raises(ValidationError):
|
|
||||||
create_graphrag_config({
|
|
||||||
"models": {
|
|
||||||
defs.DEFAULT_CHAT_MODEL_ID: missing_deployment_name_config,
|
|
||||||
defs.DEFAULT_EMBEDDING_MODEL_ID: DEFAULT_EMBEDDING_MODEL_CONFIG,
|
|
||||||
}
|
|
||||||
})
|
|
||||||
|
|
||||||
|
|
||||||
def test_default_config() -> None:
|
def test_default_config() -> None:
|
||||||
expected = get_default_graphrag_config()
|
expected = get_default_graphrag_config()
|
||||||
actual = create_graphrag_config({"models": DEFAULT_MODEL_CONFIG})
|
actual = create_graphrag_config({"models": DEFAULT_MODEL_CONFIG})
|
||||||
|
|||||||
@ -41,12 +41,14 @@ DEFAULT_CHAT_MODEL_CONFIG = {
|
|||||||
"api_key": FAKE_API_KEY,
|
"api_key": FAKE_API_KEY,
|
||||||
"type": defs.DEFAULT_CHAT_MODEL_TYPE.value,
|
"type": defs.DEFAULT_CHAT_MODEL_TYPE.value,
|
||||||
"model": defs.DEFAULT_CHAT_MODEL,
|
"model": defs.DEFAULT_CHAT_MODEL,
|
||||||
|
"model_provider": defs.DEFAULT_MODEL_PROVIDER,
|
||||||
}
|
}
|
||||||
|
|
||||||
DEFAULT_EMBEDDING_MODEL_CONFIG = {
|
DEFAULT_EMBEDDING_MODEL_CONFIG = {
|
||||||
"api_key": FAKE_API_KEY,
|
"api_key": FAKE_API_KEY,
|
||||||
"type": defs.DEFAULT_EMBEDDING_MODEL_TYPE.value,
|
"type": defs.DEFAULT_EMBEDDING_MODEL_TYPE.value,
|
||||||
"model": defs.DEFAULT_EMBEDDING_MODEL,
|
"model": defs.DEFAULT_EMBEDDING_MODEL,
|
||||||
|
"model_provider": defs.DEFAULT_MODEL_PROVIDER,
|
||||||
}
|
}
|
||||||
|
|
||||||
DEFAULT_MODEL_CONFIG = {
|
DEFAULT_MODEL_CONFIG = {
|
||||||
|
|||||||
@ -17,12 +17,14 @@ DEFAULT_CHAT_MODEL_CONFIG = {
|
|||||||
"api_key": FAKE_API_KEY,
|
"api_key": FAKE_API_KEY,
|
||||||
"type": defs.DEFAULT_CHAT_MODEL_TYPE.value,
|
"type": defs.DEFAULT_CHAT_MODEL_TYPE.value,
|
||||||
"model": defs.DEFAULT_CHAT_MODEL,
|
"model": defs.DEFAULT_CHAT_MODEL,
|
||||||
|
"model_provider": defs.DEFAULT_MODEL_PROVIDER,
|
||||||
}
|
}
|
||||||
|
|
||||||
DEFAULT_EMBEDDING_MODEL_CONFIG = {
|
DEFAULT_EMBEDDING_MODEL_CONFIG = {
|
||||||
"api_key": FAKE_API_KEY,
|
"api_key": FAKE_API_KEY,
|
||||||
"type": defs.DEFAULT_EMBEDDING_MODEL_TYPE.value,
|
"type": defs.DEFAULT_EMBEDDING_MODEL_TYPE.value,
|
||||||
"model": defs.DEFAULT_EMBEDDING_MODEL,
|
"model": defs.DEFAULT_EMBEDDING_MODEL,
|
||||||
|
"model_provider": defs.DEFAULT_MODEL_PROVIDER,
|
||||||
}
|
}
|
||||||
|
|
||||||
DEFAULT_MODEL_CONFIG = {
|
DEFAULT_MODEL_CONFIG = {
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user