mirror of
https://github.com/microsoft/graphrag.git
synced 2026-01-13 16:47:20 +08:00
Init config cleanup (#2084)
Some checks failed
gh-pages / build (push) Has been cancelled
Python CI / python-ci (ubuntu-latest, 3.10) (push) Has been cancelled
Python CI / python-ci (ubuntu-latest, 3.11) (push) Has been cancelled
Python CI / python-ci (windows-latest, 3.10) (push) Has been cancelled
Python CI / python-ci (windows-latest, 3.11) (push) Has been cancelled
Python Integration Tests / python-ci (ubuntu-latest, 3.10) (push) Has been cancelled
Python Integration Tests / python-ci (windows-latest, 3.10) (push) Has been cancelled
Python Notebook Tests / python-ci (ubuntu-latest, 3.10) (push) Has been cancelled
Python Notebook Tests / python-ci (windows-latest, 3.10) (push) Has been cancelled
Python Publish (pypi) / Upload release to PyPI (push) Has been cancelled
Python Smoke Tests / python-ci (ubuntu-latest, 3.10) (push) Has been cancelled
Python Smoke Tests / python-ci (windows-latest, 3.10) (push) Has been cancelled
Spellcheck / spellcheck (push) Has been cancelled
Some checks failed
gh-pages / build (push) Has been cancelled
Python CI / python-ci (ubuntu-latest, 3.10) (push) Has been cancelled
Python CI / python-ci (ubuntu-latest, 3.11) (push) Has been cancelled
Python CI / python-ci (windows-latest, 3.10) (push) Has been cancelled
Python CI / python-ci (windows-latest, 3.11) (push) Has been cancelled
Python Integration Tests / python-ci (ubuntu-latest, 3.10) (push) Has been cancelled
Python Integration Tests / python-ci (windows-latest, 3.10) (push) Has been cancelled
Python Notebook Tests / python-ci (ubuntu-latest, 3.10) (push) Has been cancelled
Python Notebook Tests / python-ci (windows-latest, 3.10) (push) Has been cancelled
Python Publish (pypi) / Upload release to PyPI (push) Has been cancelled
Python Smoke Tests / python-ci (ubuntu-latest, 3.10) (push) Has been cancelled
Python Smoke Tests / python-ci (windows-latest, 3.10) (push) Has been cancelled
Spellcheck / spellcheck (push) Has been cancelled
* Spruce up init_config output, including LiteLLM default * Remove deployment_name requirement for Azure * Semver * Add model_provider * Add default model_provider * Remove OBE test * Update minimal config for tests * Add model_provider to verb tests
This commit is contained in:
parent
2bd3922d8d
commit
6c86b0a7bb
@ -0,0 +1,4 @@
|
||||
{
|
||||
"type": "minor",
|
||||
"description": "Set LiteLLM as default in init_content."
|
||||
}
|
||||
@ -6,7 +6,7 @@
|
||||
from collections.abc import Callable
|
||||
from dataclasses import dataclass, field
|
||||
from pathlib import Path
|
||||
from typing import ClassVar, Literal
|
||||
from typing import ClassVar
|
||||
|
||||
from graphrag.config.embeddings import default_embeddings
|
||||
from graphrag.config.enums import (
|
||||
@ -46,13 +46,14 @@ from graphrag.language_model.providers.litellm.services.retry.retry import Retry
|
||||
|
||||
DEFAULT_OUTPUT_BASE_DIR = "output"
|
||||
DEFAULT_CHAT_MODEL_ID = "default_chat_model"
|
||||
DEFAULT_CHAT_MODEL_TYPE = ModelType.OpenAIChat
|
||||
DEFAULT_CHAT_MODEL_TYPE = ModelType.Chat
|
||||
DEFAULT_CHAT_MODEL = "gpt-4-turbo-preview"
|
||||
DEFAULT_CHAT_MODEL_AUTH_TYPE = AuthType.APIKey
|
||||
DEFAULT_EMBEDDING_MODEL_ID = "default_embedding_model"
|
||||
DEFAULT_EMBEDDING_MODEL_TYPE = ModelType.OpenAIEmbedding
|
||||
DEFAULT_EMBEDDING_MODEL_TYPE = ModelType.Embedding
|
||||
DEFAULT_EMBEDDING_MODEL = "text-embedding-3-small"
|
||||
DEFAULT_EMBEDDING_MODEL_AUTH_TYPE = AuthType.APIKey
|
||||
DEFAULT_MODEL_PROVIDER = "openai"
|
||||
DEFAULT_VECTOR_STORE_ID = "default_vector_store"
|
||||
|
||||
ENCODING_MODEL = "cl100k_base"
|
||||
@ -325,10 +326,10 @@ class LanguageModelDefaults:
|
||||
proxy: None = None
|
||||
audience: None = None
|
||||
model_supports_json: None = None
|
||||
tokens_per_minute: Literal["auto"] = "auto"
|
||||
requests_per_minute: Literal["auto"] = "auto"
|
||||
tokens_per_minute: None = None
|
||||
requests_per_minute: None = None
|
||||
rate_limit_strategy: str | None = "static"
|
||||
retry_strategy: str = "native"
|
||||
retry_strategy: str = "exponential_backoff"
|
||||
max_retries: int = 10
|
||||
max_retry_wait: float = 10.0
|
||||
concurrent_requests: int = 25
|
||||
|
||||
@ -33,15 +33,6 @@ class AzureApiVersionMissingError(ValueError):
|
||||
super().__init__(msg)
|
||||
|
||||
|
||||
class AzureDeploymentNameMissingError(ValueError):
|
||||
"""Azure Deployment Name missing error."""
|
||||
|
||||
def __init__(self, llm_type: str) -> None:
|
||||
"""Init method definition."""
|
||||
msg = f"Deployment name is required for {llm_type}. Please rerun `graphrag init` set the deployment_name."
|
||||
super().__init__(msg)
|
||||
|
||||
|
||||
class LanguageModelConfigMissingError(ValueError):
|
||||
"""Missing model configuration error."""
|
||||
|
||||
|
||||
@ -19,41 +19,34 @@ INIT_YAML = f"""\
|
||||
|
||||
models:
|
||||
{defs.DEFAULT_CHAT_MODEL_ID}:
|
||||
type: {defs.DEFAULT_CHAT_MODEL_TYPE.value} # or azure_openai_chat
|
||||
# api_base: https://<instance>.openai.azure.com
|
||||
# api_version: 2024-05-01-preview
|
||||
type: {defs.DEFAULT_CHAT_MODEL_TYPE.value}
|
||||
model_provider: {defs.DEFAULT_MODEL_PROVIDER}
|
||||
auth_type: {defs.DEFAULT_CHAT_MODEL_AUTH_TYPE.value} # or azure_managed_identity
|
||||
api_key: ${{GRAPHRAG_API_KEY}} # set this in the generated .env file
|
||||
# audience: "https://cognitiveservices.azure.com/.default"
|
||||
# organization: <organization_id>
|
||||
api_key: ${{GRAPHRAG_API_KEY}} # set this in the generated .env file, or remove if managed identity
|
||||
model: {defs.DEFAULT_CHAT_MODEL}
|
||||
# deployment_name: <azure_model_deployment_name>
|
||||
# encoding_model: {defs.ENCODING_MODEL} # automatically set by tiktoken if left undefined
|
||||
model_supports_json: true # recommended if this is available for your model.
|
||||
concurrent_requests: {language_model_defaults.concurrent_requests} # max number of simultaneous LLM requests allowed
|
||||
async_mode: {language_model_defaults.async_mode.value} # or asyncio
|
||||
retry_strategy: native
|
||||
max_retries: {language_model_defaults.max_retries}
|
||||
tokens_per_minute: {language_model_defaults.tokens_per_minute} # set to null to disable rate limiting
|
||||
requests_per_minute: {language_model_defaults.requests_per_minute} # set to null to disable rate limiting
|
||||
{defs.DEFAULT_EMBEDDING_MODEL_ID}:
|
||||
type: {defs.DEFAULT_EMBEDDING_MODEL_TYPE.value} # or azure_openai_embedding
|
||||
# api_base: https://<instance>.openai.azure.com
|
||||
# api_version: 2024-05-01-preview
|
||||
auth_type: {defs.DEFAULT_EMBEDDING_MODEL_AUTH_TYPE.value} # or azure_managed_identity
|
||||
api_key: ${{GRAPHRAG_API_KEY}}
|
||||
# audience: "https://cognitiveservices.azure.com/.default"
|
||||
# organization: <organization_id>
|
||||
model: {defs.DEFAULT_EMBEDDING_MODEL}
|
||||
# deployment_name: <azure_model_deployment_name>
|
||||
# encoding_model: {defs.ENCODING_MODEL} # automatically set by tiktoken if left undefined
|
||||
model_supports_json: true # recommended if this is available for your model.
|
||||
concurrent_requests: {language_model_defaults.concurrent_requests} # max number of simultaneous LLM requests allowed
|
||||
concurrent_requests: {language_model_defaults.concurrent_requests}
|
||||
async_mode: {language_model_defaults.async_mode.value} # or asyncio
|
||||
retry_strategy: native
|
||||
retry_strategy: {language_model_defaults.retry_strategy}
|
||||
max_retries: {language_model_defaults.max_retries}
|
||||
tokens_per_minute: null # set to null to disable rate limiting or auto for dynamic
|
||||
requests_per_minute: null # set to null to disable rate limiting or auto for dynamic
|
||||
tokens_per_minute: null
|
||||
requests_per_minute: null
|
||||
{defs.DEFAULT_EMBEDDING_MODEL_ID}:
|
||||
type: {defs.DEFAULT_EMBEDDING_MODEL_TYPE.value}
|
||||
model_provider: {defs.DEFAULT_MODEL_PROVIDER}
|
||||
auth_type: {defs.DEFAULT_EMBEDDING_MODEL_AUTH_TYPE.value}
|
||||
api_key: ${{GRAPHRAG_API_KEY}}
|
||||
model: {defs.DEFAULT_EMBEDDING_MODEL}
|
||||
# api_base: https://<instance>.openai.azure.com
|
||||
# api_version: 2024-05-01-preview
|
||||
concurrent_requests: {language_model_defaults.concurrent_requests}
|
||||
async_mode: {language_model_defaults.async_mode.value} # or asyncio
|
||||
retry_strategy: {language_model_defaults.retry_strategy}
|
||||
max_retries: {language_model_defaults.max_retries}
|
||||
tokens_per_minute: null
|
||||
requests_per_minute: null
|
||||
|
||||
### Input settings ###
|
||||
|
||||
@ -62,7 +55,6 @@ input:
|
||||
type: {graphrag_config_defaults.input.storage.type.value} # or blob
|
||||
base_dir: "{graphrag_config_defaults.input.storage.base_dir}"
|
||||
file_type: {graphrag_config_defaults.input.file_type.value} # [csv, text, json]
|
||||
|
||||
|
||||
chunks:
|
||||
size: {graphrag_config_defaults.chunks.size}
|
||||
@ -90,7 +82,6 @@ vector_store:
|
||||
type: {vector_store_defaults.type}
|
||||
db_uri: {vector_store_defaults.db_uri}
|
||||
container_name: {vector_store_defaults.container_name}
|
||||
overwrite: {vector_store_defaults.overwrite}
|
||||
|
||||
### Workflow settings ###
|
||||
|
||||
|
||||
@ -3,6 +3,7 @@
|
||||
|
||||
"""Language model configuration."""
|
||||
|
||||
import logging
|
||||
from typing import Literal
|
||||
|
||||
import tiktoken
|
||||
@ -14,11 +15,12 @@ from graphrag.config.errors import (
|
||||
ApiKeyMissingError,
|
||||
AzureApiBaseMissingError,
|
||||
AzureApiVersionMissingError,
|
||||
AzureDeploymentNameMissingError,
|
||||
ConflictingSettingsError,
|
||||
)
|
||||
from graphrag.language_model.factory import ModelFactory
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class LanguageModelConfig(BaseModel):
|
||||
"""Language model configuration."""
|
||||
@ -214,7 +216,8 @@ class LanguageModelConfig(BaseModel):
|
||||
or self.type == ModelType.AzureOpenAIEmbedding
|
||||
or self.model_provider == "azure" # indicates Litellm + AOI
|
||||
) and (self.deployment_name is None or self.deployment_name.strip() == ""):
|
||||
raise AzureDeploymentNameMissingError(self.type)
|
||||
msg = f"deployment_name is not set for Azure-hosted model. This will default to your model name ({self.model}). If different, this should be set."
|
||||
logger.debug(msg)
|
||||
|
||||
organization: str | None = Field(
|
||||
description="The organization to use for the LLM service.",
|
||||
|
||||
@ -1,9 +1,11 @@
|
||||
models:
|
||||
default_chat_model:
|
||||
api_key: ${CUSTOM_API_KEY}
|
||||
type: openai_chat
|
||||
type: chat
|
||||
model_provider: openai
|
||||
model: gpt-4-turbo-preview
|
||||
default_embedding_model:
|
||||
api_key: ${CUSTOM_API_KEY}
|
||||
type: openai_embedding
|
||||
type: embedding
|
||||
model_provider: openai
|
||||
model: text-embedding-3-small
|
||||
@ -1,9 +1,11 @@
|
||||
models:
|
||||
default_chat_model:
|
||||
api_key: ${SOME_NON_EXISTENT_ENV_VAR}
|
||||
type: openai_chat
|
||||
type: chat
|
||||
model_provider: openai
|
||||
model: gpt-4-turbo-preview
|
||||
default_embedding_model:
|
||||
api_key: ${SOME_NON_EXISTENT_ENV_VAR}
|
||||
type: openai_embedding
|
||||
type: embedding
|
||||
model_provider: openai
|
||||
model: text-embedding-3-small
|
||||
@ -133,19 +133,6 @@ def test_missing_azure_api_version() -> None:
|
||||
})
|
||||
|
||||
|
||||
def test_missing_azure_deployment_name() -> None:
|
||||
missing_deployment_name_config = base_azure_model_config.copy()
|
||||
del missing_deployment_name_config["deployment_name"]
|
||||
|
||||
with pytest.raises(ValidationError):
|
||||
create_graphrag_config({
|
||||
"models": {
|
||||
defs.DEFAULT_CHAT_MODEL_ID: missing_deployment_name_config,
|
||||
defs.DEFAULT_EMBEDDING_MODEL_ID: DEFAULT_EMBEDDING_MODEL_CONFIG,
|
||||
}
|
||||
})
|
||||
|
||||
|
||||
def test_default_config() -> None:
|
||||
expected = get_default_graphrag_config()
|
||||
actual = create_graphrag_config({"models": DEFAULT_MODEL_CONFIG})
|
||||
|
||||
@ -41,12 +41,14 @@ DEFAULT_CHAT_MODEL_CONFIG = {
|
||||
"api_key": FAKE_API_KEY,
|
||||
"type": defs.DEFAULT_CHAT_MODEL_TYPE.value,
|
||||
"model": defs.DEFAULT_CHAT_MODEL,
|
||||
"model_provider": defs.DEFAULT_MODEL_PROVIDER,
|
||||
}
|
||||
|
||||
DEFAULT_EMBEDDING_MODEL_CONFIG = {
|
||||
"api_key": FAKE_API_KEY,
|
||||
"type": defs.DEFAULT_EMBEDDING_MODEL_TYPE.value,
|
||||
"model": defs.DEFAULT_EMBEDDING_MODEL,
|
||||
"model_provider": defs.DEFAULT_MODEL_PROVIDER,
|
||||
}
|
||||
|
||||
DEFAULT_MODEL_CONFIG = {
|
||||
|
||||
@ -17,12 +17,14 @@ DEFAULT_CHAT_MODEL_CONFIG = {
|
||||
"api_key": FAKE_API_KEY,
|
||||
"type": defs.DEFAULT_CHAT_MODEL_TYPE.value,
|
||||
"model": defs.DEFAULT_CHAT_MODEL,
|
||||
"model_provider": defs.DEFAULT_MODEL_PROVIDER,
|
||||
}
|
||||
|
||||
DEFAULT_EMBEDDING_MODEL_CONFIG = {
|
||||
"api_key": FAKE_API_KEY,
|
||||
"type": defs.DEFAULT_EMBEDDING_MODEL_TYPE.value,
|
||||
"model": defs.DEFAULT_EMBEDDING_MODEL,
|
||||
"model_provider": defs.DEFAULT_MODEL_PROVIDER,
|
||||
}
|
||||
|
||||
DEFAULT_MODEL_CONFIG = {
|
||||
|
||||
Loading…
Reference in New Issue
Block a user