Mirror of https://github.com/microsoft/graphrag.git (synced 2026-01-14 09:07:20 +08:00)

Improve default llm retry logic to be more optimized (#1701)

This commit is contained in:
parent b8b949f3bb
commit f14cda2b6d
@@ -0,0 +1,4 @@
+{
+    "type": "patch",
+    "description": "add dynamic retry logic."
+}
@@ -13,6 +13,7 @@ Backwards compatibility is not guaranteed at this time.
 from pydantic import PositiveInt, validate_call
 
+import graphrag.config.defaults as defs
 from graphrag.callbacks.noop_workflow_callbacks import NoopWorkflowCallbacks
 from graphrag.config.models.graph_rag_config import GraphRagConfig
 from graphrag.index.llm.load_llm import load_llm
@@ -95,8 +96,14 @@ async def generate_indexing_prompts(
     )
 
     # Create LLM from config
-    # TODO: Expose way to specify Prompt Tuning model ID through config
+    # TODO: Expose a way to specify Prompt Tuning model ID through config
     default_llm_settings = config.get_language_model_config(PROMPT_TUNING_MODEL_ID)
+
+    # if max_retries is not set, inject a dynamically assigned value based on the number of expected LLM calls
+    # to be made or fallback to a default value in the worst case
+    if default_llm_settings.max_retries == -1:
+        default_llm_settings.max_retries = min(len(doc_list), defs.LLM_MAX_RETRIES)
+
     llm = load_llm(
         "prompt_tuning",
         default_llm_settings,
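The hunk above is the heart of the change: when max_retries is left at the -1 sentinel, the prompt-tuning path derives its retry budget from the number of expected LLM calls, capped by the LLM_MAX_RETRIES default. A minimal, standalone sketch of that rule (the helper name resolve_prompt_tuning_retries is illustrative, not part of the codebase):

# Illustrative sketch only -- mirrors the rule shown in the hunk above, not library code.
LLM_MAX_RETRIES = 10  # default cap, matching graphrag.config.defaults.LLM_MAX_RETRIES


def resolve_prompt_tuning_retries(configured_max_retries: int, expected_llm_calls: int) -> int:
    """Return the retry budget for a prompt-tuning run.

    -1 means "not set": derive the budget from the number of expected LLM calls,
    falling back to the default cap in the worst case.
    """
    if configured_max_retries == -1:
        return min(expected_llm_calls, LLM_MAX_RETRIES)
    return configured_max_retries


# Example: tuning prompts over 3 documents with max_retries left unset.
assert resolve_prompt_tuning_retries(-1, expected_llm_calls=3) == 3
assert resolve_prompt_tuning_retries(-1, expected_llm_calls=50) == 10
assert resolve_prompt_tuning_retries(5, expected_llm_calls=50) == 5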
@@ -24,7 +24,7 @@ DEFAULT_CHAT_MODEL_ID = "default_chat_model"
 DEFAULT_EMBEDDING_MODEL_ID = "default_embedding_model"
 ASYNC_MODE = AsyncType.Threaded
 ENCODING_MODEL = "cl100k_base"
-AZURE_AUDIENCE = "https://cognitiveservices.azure.com/.default"
+COGNITIVE_SERVICES_AUDIENCE = "https://cognitiveservices.azure.com/.default"
 AUTH_TYPE = AuthType.APIKey
 #
 # LLM Parameters
@@ -39,15 +39,12 @@ LLM_N = 1
 LLM_REQUEST_TIMEOUT = 180.0
 LLM_TOKENS_PER_MINUTE = 50_000
 LLM_REQUESTS_PER_MINUTE = 1_000
+RETRY_STRATEGY = "native"
 LLM_MAX_RETRIES = 10
 LLM_MAX_RETRY_WAIT = 10.0
 LLM_PRESENCE_PENALTY = 0.0
-LLM_SLEEP_ON_RATE_LIMIT_RECOMMENDATION = True
 LLM_CONCURRENT_REQUESTS = 25
 
-PARALLELIZATION_STAGGER = 0.3
-PARALLELIZATION_NUM_THREADS = 50
-
 #
 # Text embedding
 #
@@ -14,32 +14,41 @@ INIT_YAML = f"""\
 
 models:
   {defs.DEFAULT_CHAT_MODEL_ID}:
-    api_key: ${{GRAPHRAG_API_KEY}} # set this in the generated .env file
     type: {defs.LLM_TYPE.value} # or azure_openai_chat
+    # api_base: https://<instance>.openai.azure.com
+    # api_version: 2024-05-01-preview
     auth_type: {defs.AUTH_TYPE.value} # or azure_managed_identity
+    api_key: ${{GRAPHRAG_API_KEY}} # set this in the generated .env file
+    # audience: "https://cognitiveservices.azure.com/.default"
+    # organization: <organization_id>
     model: {defs.LLM_MODEL}
+    # deployment_name: <azure_model_deployment_name>
+    # encoding_model: {defs.ENCODING_MODEL} # automatically set by tiktoken if left undefined
     model_supports_json: true # recommended if this is available for your model.
-    parallelization_num_threads: {defs.PARALLELIZATION_NUM_THREADS}
-    parallelization_stagger: {defs.PARALLELIZATION_STAGGER}
+    concurrent_requests: {defs.LLM_CONCURRENT_REQUESTS} # max number of simultaneous LLM requests allowed
     async_mode: {defs.ASYNC_MODE.value} # or asyncio
-    # audience: "https://cognitiveservices.azure.com/.default"
-    # api_base: https://<instance>.openai.azure.com
-    # api_version: 2024-02-15-preview
-    # organization: <organization_id>
-    # deployment_name: <azure_model_deployment_name>
+    retry_strategy: native
+    max_retries: -1 # set to -1 for dynamic retry logic (most optimal setting based on server response)
+    tokens_per_minute: 0 # set to 0 to disable rate limiting
+    requests_per_minute: 0 # set to 0 to disable rate limiting
   {defs.DEFAULT_EMBEDDING_MODEL_ID}:
-    api_key: ${{GRAPHRAG_API_KEY}}
     type: {defs.EMBEDDING_TYPE.value} # or azure_openai_embedding
-    auth_type: {defs.AUTH_TYPE.value} # or azure_managed_identity
-    model: {defs.EMBEDDING_MODEL}
-    parallelization_num_threads: {defs.PARALLELIZATION_NUM_THREADS}
-    parallelization_stagger: {defs.PARALLELIZATION_STAGGER}
-    async_mode: {defs.ASYNC_MODE.value} # or asyncio
     # api_base: https://<instance>.openai.azure.com
-    # api_version: 2024-02-15-preview
+    # api_version: 2024-05-01-preview
+    auth_type: {defs.AUTH_TYPE.value} # or azure_managed_identity
+    api_key: ${{GRAPHRAG_API_KEY}}
     # audience: "https://cognitiveservices.azure.com/.default"
     # organization: <organization_id>
+    model: {defs.EMBEDDING_MODEL}
     # deployment_name: <azure_model_deployment_name>
+    # encoding_model: {defs.ENCODING_MODEL} # automatically set by tiktoken if left undefined
+    model_supports_json: true # recommended if this is available for your model.
+    concurrent_requests: {defs.LLM_CONCURRENT_REQUESTS} # max number of simultaneous LLM requests allowed
+    async_mode: {defs.ASYNC_MODE.value} # or asyncio
+    retry_strategy: native
+    max_retries: -1 # set to -1 for dynamic retry logic (most optimal setting based on server response)
+    tokens_per_minute: 0 # set to 0 to disable rate limiting
+    requests_per_minute: 0 # set to 0 to disable rate limiting
 
 vector_store:
   {defs.VECTOR_STORE_DEFAULT_ID}:
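The regenerated template now surfaces the retry knobs directly in each model block: retry_strategy: native, max_retries: -1 for dynamic retries, and tokens_per_minute / requests_per_minute set to 0 to disable rate limiting. A hedged sketch of how those sentinel values could be interpreted after loading, using a hypothetical normalize_model_settings helper rather than graphrag's actual config loader:

# Illustrative sketch only -- hypothetical helper, not graphrag's config loader.
from typing import Any


def normalize_model_settings(model: dict[str, Any]) -> dict[str, Any]:
    """Interpret the sentinel values used in the generated settings template."""
    normalized = dict(model)
    # -1 means "derive the retry budget dynamically at runtime".
    normalized["dynamic_retries"] = normalized.get("max_retries", -1) == -1
    # 0 means "no rate limiting" for both token and request budgets.
    normalized["rate_limited"] = bool(
        normalized.get("tokens_per_minute", 0) or normalized.get("requests_per_minute", 0)
    )
    return normalized


settings = {
    "type": "openai_chat",
    "retry_strategy": "native",
    "max_retries": -1,
    "tokens_per_minute": 0,
    "requests_per_minute": 0,
}
print(normalize_model_settings(settings))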
@@ -49,8 +49,7 @@ class CommunityReportsConfig(BaseModel):
         return self.strategy or {
             "type": CreateCommunityReportsStrategyType.graph_intelligence,
             "llm": model_config.model_dump(),
-            "stagger": model_config.parallelization_stagger,
-            "num_threads": model_config.parallelization_num_threads,
+            "num_threads": model_config.concurrent_requests,
             "graph_prompt": (Path(root_dir) / self.graph_prompt).read_text(
                 encoding="utf-8"
             )
@@ -46,8 +46,7 @@ class ClaimExtractionConfig(BaseModel):
         """Get the resolved claim extraction strategy."""
         return self.strategy or {
             "llm": model_config.model_dump(),
-            "stagger": model_config.parallelization_stagger,
-            "num_threads": model_config.parallelization_num_threads,
+            "num_threads": model_config.concurrent_requests,
             "extraction_prompt": (Path(root_dir) / self.prompt).read_text(
                 encoding="utf-8"
             )
@@ -47,8 +47,7 @@ class ExtractGraphConfig(BaseModel):
         return self.strategy or {
             "type": ExtractEntityStrategyType.graph_intelligence,
             "llm": model_config.model_dump(),
-            "stagger": model_config.parallelization_stagger,
-            "num_threads": model_config.parallelization_num_threads,
+            "num_threads": model_config.concurrent_requests,
             "extraction_prompt": (Path(root_dir) / self.prompt).read_text(
                 encoding="utf-8"
             )
@@ -64,7 +64,7 @@ class ExtractGraphNLPConfig(BaseModel):
     text_analyzer: TextAnalyzerConfig = Field(
         description="The text analyzer configuration.", default=TextAnalyzerConfig()
     )
-    parallelization_num_threads: int = Field(
+    concurrent_requests: int = Field(
         description="The number of threads to use for the extraction process.",
-        default=defs.PARALLELIZATION_NUM_THREADS,
+        default=defs.LLM_CONCURRENT_REQUESTS,
     )
@@ -31,7 +31,7 @@ class LanguageModelConfig(BaseModel):
         API Key is required when using OpenAI API
         or when using Azure API with API Key authentication.
         For the time being, this check is extra verbose for clarity.
-        It will also through an exception if an API Key is provided
+        It will also raise an exception if an API Key is provided
         when one is not expected such as the case of using Azure
         Managed Identity.
 
@@ -199,6 +199,10 @@ class LanguageModelConfig(BaseModel):
         description="The number of requests per minute to use for the LLM service.",
         default=defs.LLM_REQUESTS_PER_MINUTE,
     )
+    retry_strategy: str = Field(
+        description="The retry strategy to use for the LLM service.",
+        default=defs.RETRY_STRATEGY,
+    )
     max_retries: int = Field(
         description="The maximum number of retries to use for the LLM service.",
         default=defs.LLM_MAX_RETRIES,
@@ -207,10 +211,6 @@ class LanguageModelConfig(BaseModel):
         description="The maximum retry wait to use for the LLM service.",
         default=defs.LLM_MAX_RETRY_WAIT,
     )
-    sleep_on_rate_limit_recommendation: bool = Field(
-        description="Whether to sleep on rate limit recommendations.",
-        default=defs.LLM_SLEEP_ON_RATE_LIMIT_RECOMMENDATION,
-    )
     concurrent_requests: int = Field(
         description="Whether to use concurrent requests for the LLM service.",
         default=defs.LLM_CONCURRENT_REQUESTS,
@@ -218,14 +218,6 @@ class LanguageModelConfig(BaseModel):
     responses: list[str | BaseModel] | None = Field(
         default=None, description="Static responses to use in mock mode."
     )
-    parallelization_stagger: float = Field(
-        description="The stagger to use for the LLM service.",
-        default=defs.PARALLELIZATION_STAGGER,
-    )
-    parallelization_num_threads: int = Field(
-        description="The number of threads to use for the LLM service.",
-        default=defs.PARALLELIZATION_NUM_THREADS,
-    )
     async_mode: AsyncType = Field(
         description="The async mode to use.", default=defs.ASYNC_MODE
     )
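The config model gains a retry_strategy field next to the existing max_retries and max_retry_wait, while sleep_on_rate_limit_recommendation and the parallelization_* fields are removed. A pared-down pydantic sketch of just the retry-related slice (the RetrySettings class is illustrative, not graphrag's LanguageModelConfig):

# Illustrative sketch only -- a pared-down model, not graphrag's LanguageModelConfig.
from pydantic import BaseModel, Field

RETRY_STRATEGY = "native"
LLM_MAX_RETRIES = 10
LLM_MAX_RETRY_WAIT = 10.0


class RetrySettings(BaseModel):
    """Retry-related settings mirroring the fields added/kept above."""

    retry_strategy: str = Field(
        description="The retry strategy to use for the LLM service.",
        default=RETRY_STRATEGY,
    )
    max_retries: int = Field(
        description="The maximum number of retries to use for the LLM service.",
        default=LLM_MAX_RETRIES,
    )
    max_retry_wait: float = Field(
        description="The maximum retry wait to use for the LLM service.",
        default=LLM_MAX_RETRY_WAIT,
    )


# -1 is the "derive dynamically at runtime" sentinel used throughout this commit.
print(RetrySettings(max_retries=-1).model_dump())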
@@ -40,8 +40,7 @@ class SummarizeDescriptionsConfig(BaseModel):
         return self.strategy or {
             "type": SummarizeStrategyType.graph_intelligence,
             "llm": model_config.model_dump(),
-            "stagger": model_config.parallelization_stagger,
-            "num_threads": model_config.parallelization_num_threads,
+            "num_threads": model_config.concurrent_requests,
             "summarize_prompt": (Path(root_dir) / self.prompt).read_text(
                 encoding="utf-8"
             )
@@ -48,8 +48,7 @@ class TextEmbeddingConfig(BaseModel):
         return self.strategy or {
             "type": TextEmbedStrategyType.openai,
             "llm": model_config.model_dump(),
-            "stagger": model_config.parallelization_stagger,
-            "num_threads": model_config.parallelization_num_threads,
+            "num_threads": model_config.concurrent_requests,
             "batch_size": self.batch_size,
             "batch_max_tokens": self.batch_max_tokens,
         }
@@ -37,7 +37,7 @@ async def extract_graph_nlp(
         text_units,
         text_analyzer=text_analyzer,
         normalize_edge_weights=extraction_config.normalize_edge_weights,
-        num_threads=extraction_config.parallelization_num_threads,
+        num_threads=extraction_config.concurrent_requests,
         cache=cache,
     )
 
@@ -8,8 +8,9 @@ from __future__ import annotations
 import logging
 from typing import TYPE_CHECKING, Any
 
-from fnllm import ChatLLM, EmbeddingsLLM, JsonStrategy, LLMEvents
+from fnllm.base.config import JsonStrategy, RetryStrategy
 from fnllm.caching import Cache as LLMCache
+from fnllm.events import LLMEvents
 from fnllm.openai import (
     AzureOpenAIConfig,
     OpenAIConfig,
@@ -30,6 +31,8 @@ from graphrag.index.llm.manager import ChatLLMSingleton, EmbeddingsLLMSingleton
 from .mock_llm import MockChatLLM
 
 if TYPE_CHECKING:
+    from fnllm.types import ChatLLM, EmbeddingsLLM
+
     from graphrag.cache.pipeline_cache import PipelineCache
     from graphrag.callbacks.workflow_callbacks import WorkflowCallbacks
     from graphrag.index.typing import ErrorHandlerFn
@@ -209,7 +212,7 @@ def _create_openai_config(config: LanguageModelConfig, azure: bool) -> OpenAICon
             msg = "Azure OpenAI Chat LLM requires an API base"
             raise ValueError(msg)
 
-        audience = config.audience or defs.AZURE_AUDIENCE
+        audience = config.audience or defs.COGNITIVE_SERVICES_AUDIENCE
         return AzureOpenAIConfig(
             api_key=config.api_key,
             endpoint=config.api_base,
@@ -220,19 +223,22 @@ def _create_openai_config(config: LanguageModelConfig, azure: bool) -> OpenAICon
             max_retry_wait=config.max_retry_wait,
             requests_per_minute=config.requests_per_minute,
             tokens_per_minute=config.tokens_per_minute,
-            cognitive_services_endpoint=audience,
+            audience=audience,
+            retry_strategy=RetryStrategy(config.retry_strategy),
             timeout=config.request_timeout,
             max_concurrency=config.concurrent_requests,
             model=config.model,
             encoding=encoding_model,
             deployment=config.deployment_name,
             chat_parameters=chat_parameters,
+            sleep_on_rate_limit_recommendation=True,
         )
     return PublicOpenAIConfig(
         api_key=config.api_key,
         base_url=config.api_base,
         json_strategy=json_strategy,
         organization=config.organization,
+        retry_strategy=RetryStrategy(config.retry_strategy),
         max_retries=config.max_retries,
         max_retry_wait=config.max_retry_wait,
         requests_per_minute=config.requests_per_minute,
@@ -242,6 +248,7 @@ def _create_openai_config(config: LanguageModelConfig, azure: bool) -> OpenAICon
         model=config.model,
         encoding=encoding_model,
         chat_parameters=chat_parameters,
+        sleep_on_rate_limit_recommendation=True,
     )
 
 
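The factory above now passes retry_strategy (wrapped in fnllm's RetryStrategy type) and a hard-coded sleep_on_rate_limit_recommendation=True into both OpenAI config variants, instead of reading a sleep flag from config. Conceptually, a "native" strategy governed by max_retries and max_retry_wait usually amounts to capped, jittered exponential backoff; the following is a generic, library-agnostic sketch of that idea, not fnllm's actual implementation:

# Illustrative sketch only -- generic capped exponential backoff, not fnllm's retry code.
import random
import time
from collections.abc import Callable
from typing import TypeVar

T = TypeVar("T")


def retry_with_backoff(
    call: Callable[[], T],
    max_retries: int = 10,
    max_retry_wait: float = 10.0,
    retryable: tuple[type[Exception], ...] = (TimeoutError, ConnectionError),
) -> T:
    """Retry `call` with jittered exponential backoff, capped at `max_retry_wait` seconds."""
    for attempt in range(max_retries + 1):
        try:
            return call()
        except retryable:
            if attempt == max_retries:
                raise
            wait = min(max_retry_wait, (2**attempt) * 0.5)
            time.sleep(wait * random.uniform(0.5, 1.0))
    raise RuntimeError("unreachable")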
@@ -5,7 +5,7 @@
 
 from functools import cache
 
-from fnllm import ChatLLM, EmbeddingsLLM
+from fnllm.types import ChatLLM, EmbeddingsLLM
 
 
 @cache
@@ -5,8 +5,9 @@
 from dataclasses import dataclass
 from typing import Any, cast
 
-from fnllm import ChatLLM, LLMInput, LLMOutput
+from fnllm.types import ChatLLM
 from fnllm.types.generics import THistoryEntry, TJsonModel, TModelParameters
+from fnllm.types.io import LLMInput, LLMOutput
 from pydantic import BaseModel
 from typing_extensions import Unpack
 
@@ -116,10 +116,10 @@ async def _text_embed_in_memory(
 ):
     strategy_type = strategy["type"]
     strategy_exec = load_strategy(strategy_type)
-    strategy_args = {**strategy}
+    strategy_config = {**strategy}
 
     texts: list[str] = input[embed_column].to_numpy().tolist()
-    result = await strategy_exec(texts, callbacks, cache, strategy_args)
+    result = await strategy_exec(texts, callbacks, cache, strategy_config)
 
     return result.embeddings
 
@@ -137,7 +137,11 @@ async def _text_embed_with_vector_store(
 ):
     strategy_type = strategy["type"]
     strategy_exec = load_strategy(strategy_type)
-    strategy_args = {**strategy}
+    strategy_config = {**strategy}
+
+    # if max_retries is not set, inject a dynamically assigned value based on the total number of expected LLM calls to be made
+    if strategy_config.get("llm") and strategy_config["llm"]["max_retries"] == -1:
+        strategy_config["llm"]["max_retries"] = len(input)
 
     # Get vector-storage configuration
     insert_batch_size: int = (
@@ -176,7 +180,7 @@ async def _text_embed_with_vector_store(
         texts: list[str] = batch[embed_column].to_numpy().tolist()
         titles: list[str] = batch[title].to_numpy().tolist()
         ids: list[str] = batch[id_column].to_numpy().tolist()
-        result = await strategy_exec(texts, callbacks, cache, strategy_args)
+        result = await strategy_exec(texts, callbacks, cache, strategy_config)
         if result.embeddings:
             embeddings = [
                 embedding for embedding in result.embeddings if embedding is not None
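This copy-and-override step is the same pattern repeated in the covariate-extraction, graph-extraction, community-report, and description-summarization hunks that follow: clone the strategy dict and, if the llm block still carries the -1 sentinel, size the retry budget by the number of rows about to be processed. A small standalone sketch of the pattern (inject_dynamic_retries is an illustrative name, and unlike the diff it avoids mutating the nested dict in place):

# Illustrative sketch only -- mirrors the per-workflow injection pattern, not library code.
from typing import Any


def inject_dynamic_retries(strategy: dict[str, Any], expected_llm_calls: int) -> dict[str, Any]:
    """Copy `strategy` and replace a max_retries of -1 with the expected number of LLM calls."""
    strategy_config = {**strategy}
    if strategy_config.get("llm") and strategy_config["llm"]["max_retries"] == -1:
        # Copy the nested dict too, so the caller's strategy is not mutated.
        strategy_config["llm"] = {**strategy_config["llm"], "max_retries": expected_llm_calls}
    return strategy_config


strategy = {"type": "openai", "llm": {"model": "gpt-4o-mini", "max_retries": -1}}
print(inject_dynamic_retries(strategy, expected_llm_calls=128)["llm"]["max_retries"])  # 128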
@@ -8,7 +8,7 @@ import logging
 from typing import Any
 
 import numpy as np
-from fnllm import EmbeddingsLLM
+from fnllm.types import EmbeddingsLLM
 
 from graphrag.cache.pipeline_cache import PipelineCache
 from graphrag.callbacks.workflow_callbacks import WorkflowCallbacks
@@ -9,7 +9,7 @@ from dataclasses import dataclass
 from typing import Any
 
 import tiktoken
-from fnllm import ChatLLM
+from fnllm.types import ChatLLM
 
 import graphrag.config.defaults as defs
 from graphrag.index.typing import ErrorHandlerFn
@@ -50,6 +50,10 @@ async def extract_covariates(
     strategy = strategy or {}
     strategy_config = {**strategy}
 
+    # if max_retries is not set, inject a dynamically assigned value based on the total number of expected LLM calls to be made
+    if strategy_config.get("llm") and strategy_config["llm"]["max_retries"] == -1:
+        strategy_config["llm"]["max_retries"] = len(input)
+
     async def run_strategy(row):
         text = row[column]
         result = await run_extract_claims(
@@ -94,6 +94,10 @@ async def extract_graph(
     )
     strategy_config = {**strategy}
 
+    # if max_retries is not set, inject a dynamically assigned value based on the total number of expected LLM calls to be made
+    if strategy_config.get("llm") and strategy_config["llm"]["max_retries"] == -1:
+        strategy_config["llm"]["max_retries"] = len(text_units)
+
     num_started = 0
 
     async def run_strategy(row):
@@ -12,7 +12,7 @@ from typing import Any
 
 import networkx as nx
 import tiktoken
-from fnllm import ChatLLM
+from fnllm.types import ChatLLM
 
 import graphrag.config.defaults as defs
 from graphrag.index.typing import ErrorHandlerFn
@@ -4,7 +4,7 @@
 """A module containing run_graph_intelligence, run_extract_graph and _create_text_splitter methods to run graph intelligence."""
 
 import networkx as nx
-from fnllm import ChatLLM
+from fnllm.types import ChatLLM
 
 import graphrag.config.defaults as defs
 from graphrag.cache.pipeline_cache import PipelineCache
@@ -8,7 +8,7 @@ import traceback
 from dataclasses import dataclass
 from typing import Any
 
-from fnllm import ChatLLM
+from fnllm.types import ChatLLM
 from pydantic import BaseModel, Field
 
 from graphrag.index.typing import ErrorHandlerFn
@@ -6,7 +6,7 @@
 import logging
 import traceback
 
-from fnllm import ChatLLM
+from fnllm.types import ChatLLM
 
 from graphrag.cache.pipeline_cache import PipelineCache
 from graphrag.callbacks.workflow_callbacks import WorkflowCallbacks
@@ -42,7 +42,12 @@ async def summarize_communities(
     """Generate community summaries."""
     reports: list[CommunityReport | None] = []
     tick = progress_ticker(callbacks.progress, len(local_contexts))
-    runner = load_strategy(strategy["type"])
+    strategy_exec = load_strategy(strategy["type"])
+    strategy_config = {**strategy}
+
+    # if max_retries is not set, inject a dynamically assigned value based on the total number of expected LLM calls to be made
+    if strategy_config.get("llm") and strategy_config["llm"]["max_retries"] == -1:
+        strategy_config["llm"]["max_retries"] = len(nodes)
 
     community_hierarchy = restore_community_hierarchy(nodes)
     levels = get_levels(nodes)
@@ -62,13 +67,13 @@ async def summarize_communities(
 
     async def run_generate(record):
         result = await _generate_report(
-            runner,
+            strategy_exec,
             community_id=record[schemas.COMMUNITY_ID],
             community_level=record[schemas.COMMUNITY_LEVEL],
             community_context=record[schemas.CONTEXT_STRING],
             callbacks=callbacks,
             cache=cache,
-            strategy=strategy,
+            strategy=strategy_config,
         )
         tick()
         return result
@@ -6,7 +6,7 @@
 import json
 from dataclasses import dataclass
 
-from fnllm import ChatLLM
+from fnllm.types import ChatLLM
 
 from graphrag.index.typing import ErrorHandlerFn
 from graphrag.index.utils.tokens import num_tokens_from_string
@@ -3,7 +3,7 @@
 
 """A module containing run_graph_intelligence, run_resolve_entities and _create_text_list_splitter methods to run graph intelligence."""
 
-from fnllm import ChatLLM
+from fnllm.types import ChatLLM
 
 from graphrag.cache.pipeline_cache import PipelineCache
 from graphrag.callbacks.workflow_callbacks import WorkflowCallbacks
@@ -76,6 +76,10 @@ async def summarize_descriptions(
     )
     strategy_config = {**strategy}
 
+    # if max_retries is not set, inject a dynamically assigned value based on the maximum number of expected LLM calls to be made
+    if strategy_config.get("llm") and strategy_config["llm"]["max_retries"] == -1:
+        strategy_config["llm"]["max_retries"] = len(entities_df) + len(relationships_df)
+
     async def get_summarized(
         nodes: pd.DataFrame, edges: pd.DataFrame, semaphore: asyncio.Semaphore
     ):
@@ -6,6 +6,7 @@
 import asyncio
 import sys
 
+import graphrag.config.defaults as defs
 from graphrag.callbacks.noop_workflow_callbacks import NoopWorkflowCallbacks
 from graphrag.config.models.graph_rag_config import GraphRagConfig
 from graphrag.index.llm.load_llm import load_llm, load_llm_embeddings
@@ -17,6 +18,9 @@ def validate_config_names(logger: ProgressLogger, parameters: GraphRagConfig) ->
     # Validate Chat LLM configs
     # TODO: Replace default_chat_model with a way to select the model
     default_llm_settings = parameters.get_language_model_config("default_chat_model")
+    # if max_retries is not set, set it to the default value
+    if default_llm_settings.max_retries == -1:
+        default_llm_settings.max_retries = defs.LLM_MAX_RETRIES
     llm = load_llm(
         name="test-llm",
         config=default_llm_settings,
@@ -39,7 +39,7 @@ async def run_workflow(
         config.community_reports.model_id
     )
     async_mode = community_reports_llm_settings.async_mode
-    num_threads = community_reports_llm_settings.parallelization_num_threads
+    num_threads = community_reports_llm_settings.concurrent_requests
     summarization_strategy = config.community_reports.resolved_strategy(
         config.root_dir, community_reports_llm_settings
     )
@@ -31,7 +31,7 @@ async def run_workflow(
         config.community_reports.model_id
     )
     async_mode = community_reports_llm_settings.async_mode
-    num_threads = community_reports_llm_settings.parallelization_num_threads
+    num_threads = community_reports_llm_settings.concurrent_requests
     summarization_strategy = config.community_reports.resolved_strategy(
         config.root_dir, community_reports_llm_settings
     )
@@ -32,7 +32,7 @@ async def run_workflow(
     )
 
     async_mode = extract_claims_llm_settings.async_mode
-    num_threads = extract_claims_llm_settings.parallelization_num_threads
+    num_threads = extract_claims_llm_settings.concurrent_requests
 
     output = await extract_covariates(
         text_units,
@@ -32,9 +32,6 @@ async def run_workflow(
     extraction_strategy = config.extract_graph.resolved_strategy(
         config.root_dir, extract_graph_llm_settings
     )
-    extraction_num_threads = extract_graph_llm_settings.parallelization_num_threads
-    extraction_async_mode = extract_graph_llm_settings.async_mode
-    entity_types = config.extract_graph.entity_types
 
     summarization_llm_settings = config.get_language_model_config(
         config.summarize_descriptions.model_id
@@ -42,18 +39,17 @@ async def run_workflow(
     summarization_strategy = config.summarize_descriptions.resolved_strategy(
         config.root_dir, summarization_llm_settings
     )
-    summarization_num_threads = summarization_llm_settings.parallelization_num_threads
 
     entities, relationships = await extract_graph(
         text_units=text_units,
         callbacks=callbacks,
         cache=context.cache,
         extraction_strategy=extraction_strategy,
-        extraction_num_threads=extraction_num_threads,
-        extraction_async_mode=extraction_async_mode,
-        entity_types=entity_types,
+        extraction_num_threads=extract_graph_llm_settings.concurrent_requests,
+        extraction_async_mode=extract_graph_llm_settings.async_mode,
+        entity_types=config.extract_graph.entity_types,
         summarization_strategy=summarization_strategy,
-        summarization_num_threads=summarization_num_threads,
+        summarization_num_threads=summarization_llm_settings.concurrent_requests,
         embed_config=config.embed_graph,
         layout_enabled=config.umap.enabled,
     )
@@ -3,7 +3,7 @@
 # Copyright (c) 2024 Microsoft Corporation.
 # Licensed under the MIT License
 
-from fnllm import ChatLLM
+from fnllm.types import ChatLLM
 
 from graphrag.prompt_tune.prompt.community_report_rating import (
     GENERATE_REPORT_RATING_PROMPT,
@@ -3,7 +3,7 @@
 
 """Generate a community reporter role for community summarization."""
 
-from fnllm import ChatLLM
+from fnllm.types import ChatLLM
 
 from graphrag.prompt_tune.prompt.community_reporter_role import (
     GENERATE_COMMUNITY_REPORTER_ROLE_PROMPT,
@@ -3,7 +3,7 @@
 
 """Domain generation for GraphRAG prompts."""
 
-from fnllm import ChatLLM
+from fnllm.types import ChatLLM
 
 from graphrag.prompt_tune.prompt.domain import GENERATE_DOMAIN_PROMPT
 
@@ -6,7 +6,7 @@
 import asyncio
 import json
 
-from fnllm import ChatLLM
+from fnllm.types import ChatLLM
 
 from graphrag.prompt_tune.prompt.entity_relationship import (
     ENTITY_RELATIONSHIPS_GENERATION_JSON_PROMPT,
@@ -3,7 +3,7 @@
 
 """Entity type generation module for fine-tuning."""
 
-from fnllm import ChatLLM
+from fnllm.types import ChatLLM
 from pydantic import BaseModel
 
 from graphrag.prompt_tune.defaults import DEFAULT_TASK
@@ -3,7 +3,7 @@
 
 """Language detection for GraphRAG prompts."""
 
-from fnllm import ChatLLM
+from fnllm.types import ChatLLM
 
 from graphrag.prompt_tune.prompt.language import DETECT_LANGUAGE_PROMPT
 
@@ -3,7 +3,7 @@
 
 """Persona generating module for fine-tuning GraphRAG prompts."""
 
-from fnllm import ChatLLM
+from fnllm.types import ChatLLM
 
 from graphrag.prompt_tune.defaults import DEFAULT_TASK
 from graphrag.prompt_tune.prompt.persona import GENERATE_PERSONA_PROMPT
@@ -5,7 +5,7 @@
 
 import numpy as np
 import pandas as pd
-from fnllm import ChatLLM
+from fnllm.types import ChatLLM
 
 import graphrag.config.defaults as defs
 from graphrag.callbacks.noop_workflow_callbacks import NoopWorkflowCallbacks
@@ -5,6 +5,7 @@
 
 from azure.identity import DefaultAzureCredential, get_bearer_token_provider
 
+import graphrag.config.defaults as defs
 from graphrag.config.enums import AuthType, LLMType
 from graphrag.config.models.graph_rag_config import GraphRagConfig
 from graphrag.query.llm.oai.chat_openai import ChatOpenAI
@@ -14,67 +15,68 @@ from graphrag.query.llm.oai.typing import OpenaiApiType
 
 def get_llm(config: GraphRagConfig) -> ChatOpenAI:
     """Get the LLM client."""
-    default_llm_settings = config.get_language_model_config("default_chat_model")
-    is_azure_client = default_llm_settings.type == LLMType.AzureOpenAIChat
-    debug_llm_key = default_llm_settings.api_key or ""
+    llm_config = config.get_language_model_config("default_chat_model")
+    is_azure_client = llm_config.type == LLMType.AzureOpenAIChat
+    debug_llm_key = llm_config.api_key or ""
     llm_debug_info = {
-        **default_llm_settings.model_dump(),
+        **llm_config.model_dump(),
         "api_key": f"REDACTED,len={len(debug_llm_key)}",
     }
     audience = (
-        default_llm_settings.audience
-        if default_llm_settings.audience
+        llm_config.audience
+        if llm_config.audience
         else "https://cognitiveservices.azure.com/.default"
     )
     print(f"creating llm client with {llm_debug_info}")  # noqa T201
     return ChatOpenAI(
-        api_key=default_llm_settings.api_key,
+        api_key=llm_config.api_key,
         azure_ad_token_provider=(
             get_bearer_token_provider(DefaultAzureCredential(), audience)
-            if is_azure_client
-            and default_llm_settings.auth_type == AuthType.AzureManagedIdentity
+            if is_azure_client and llm_config.auth_type == AuthType.AzureManagedIdentity
            else None
        ),
-        api_base=default_llm_settings.api_base,
-        organization=default_llm_settings.organization,
-        model=default_llm_settings.model,
+        api_base=llm_config.api_base,
+        organization=llm_config.organization,
+        model=llm_config.model,
         api_type=OpenaiApiType.AzureOpenAI if is_azure_client else OpenaiApiType.OpenAI,
-        deployment_name=default_llm_settings.deployment_name,
-        api_version=default_llm_settings.api_version,
-        max_retries=default_llm_settings.max_retries,
-        request_timeout=default_llm_settings.request_timeout,
+        deployment_name=llm_config.deployment_name,
+        api_version=llm_config.api_version,
+        max_retries=llm_config.max_retries
+        if llm_config.max_retries != -1
+        else defs.LLM_MAX_RETRIES,
+        request_timeout=llm_config.request_timeout,
     )
 
 
 def get_text_embedder(config: GraphRagConfig) -> OpenAIEmbedding:
     """Get the LLM client for embeddings."""
-    embeddings_llm_settings = config.get_language_model_config(
-        config.embed_text.model_id
-    )
-    is_azure_client = embeddings_llm_settings.type == LLMType.AzureOpenAIEmbedding
-    debug_embedding_api_key = embeddings_llm_settings.api_key or ""
+    embeddings_llm_config = config.get_language_model_config(config.embed_text.model_id)
+    is_azure_client = embeddings_llm_config.type == LLMType.AzureOpenAIEmbedding
+    debug_embedding_api_key = embeddings_llm_config.api_key or ""
     llm_debug_info = {
-        **embeddings_llm_settings.model_dump(),
+        **embeddings_llm_config.model_dump(),
         "api_key": f"REDACTED,len={len(debug_embedding_api_key)}",
     }
-    if embeddings_llm_settings.audience is None:
+    if embeddings_llm_config.audience is None:
         audience = "https://cognitiveservices.azure.com/.default"
     else:
-        audience = embeddings_llm_settings.audience
+        audience = embeddings_llm_config.audience
     print(f"creating embedding llm client with {llm_debug_info}")  # noqa T201
     return OpenAIEmbedding(
-        api_key=embeddings_llm_settings.api_key,
+        api_key=embeddings_llm_config.api_key,
         azure_ad_token_provider=(
             get_bearer_token_provider(DefaultAzureCredential(), audience)
             if is_azure_client
-            and embeddings_llm_settings.auth_type == AuthType.AzureManagedIdentity
+            and embeddings_llm_config.auth_type == AuthType.AzureManagedIdentity
            else None
        ),
-        api_base=embeddings_llm_settings.api_base,
-        organization=embeddings_llm_settings.organization,
+        api_base=embeddings_llm_config.api_base,
+        organization=embeddings_llm_config.organization,
         api_type=OpenaiApiType.AzureOpenAI if is_azure_client else OpenaiApiType.OpenAI,
-        model=embeddings_llm_settings.model,
-        deployment_name=embeddings_llm_settings.deployment_name,
-        api_version=embeddings_llm_settings.api_version,
-        max_retries=embeddings_llm_settings.max_retries,
+        model=embeddings_llm_config.model,
+        deployment_name=embeddings_llm_config.deployment_name,
+        api_version=embeddings_llm_config.api_version,
+        max_retries=embeddings_llm_config.max_retries
+        if embeddings_llm_config.max_retries != -1
+        else defs.LLM_MAX_RETRIES,
     )
864  poetry.lock (generated)
File diff suppressed because it is too large.
@@ -57,11 +57,15 @@ lancedb = "^0.17.0"
 aiofiles = "^24.1.0"
 
 # LLM
+fnllm = {extras = ["azure", "openai"], version = "^0.1.2"}
+httpx = "^0.28.1"
+json-repair = "^0.30.3"
 openai = "^1.57.0"
 nltk = "3.9.1"
+tenacity = "^9.0.0"
 tiktoken = "^0.8.0"
 
-# Data-Sci
+# Data-Science
 numpy = "^1.25.2"
 graspologic = "^3.4.1"
 networkx = "^3.4.2"
@@ -78,22 +82,18 @@ rich = "^13.9.4"
 devtools = "^0.12.2"
 typing-extensions = "^4.12.2"
 
-#Azure
+# Azure
 azure-cosmos = "^4.9.0"
 azure-identity = "^1.19.0"
 azure-storage-blob = "^12.24.0"
 
 future = "^1.0.0" # Needed until graspologic fixes their dependency
 typer = "^0.15.1"
-fnllm = "^0.0.10"
-tenacity = "^9.0.0"
-json-repair = "^0.30.3"
 tqdm = "^4.67.1"
-httpx = "^0.28.1"
 
 textblob = "^0.18.0.post0"
 spacy = "^3.8.4"
 
 [tool.poetry.group.dev.dependencies]
 coverage = "^7.6.9"
 ipykernel = "^6.29.5"
2  tests/fixtures/min-csv/config.json (vendored)
@@ -15,7 +15,7 @@
             1,
             2500
         ],
-        "max_runtime": 300,
+        "max_runtime": 500,
         "expected_artifacts": 2
     },
     "create_communities": {
6  tests/fixtures/min-csv/settings.yml (vendored)
@@ -10,8 +10,7 @@ models:
     tokens_per_minute: ${GRAPHRAG_LLM_TPM}
     requests_per_minute: ${GRAPHRAG_LLM_RPM}
     model_supports_json: true
-    parallelization_num_threads: 50
-    parallelization_stagger: 0.3
+    concurrent_requests: 50
     async_mode: threaded
   default_embedding_model:
     azure_auth_type: api_key
@@ -23,8 +22,7 @@ models:
     model: ${GRAPHRAG_EMBEDDING_MODEL}
     tokens_per_minute: ${GRAPHRAG_EMBEDDING_TPM}
     requests_per_minute: ${GRAPHRAG_EMBEDDING_RPM}
-    parallelization_num_threads: 50
-    parallelization_stagger: 0.3
+    concurrent_requests: 50
     async_mode: threaded
 
 vector_store:
2  tests/fixtures/text/config.json (vendored)
@@ -15,7 +15,7 @@
             1,
             2500
         ],
-        "max_runtime": 300,
+        "max_runtime": 500,
         "expected_artifacts": 2
     },
     "extract_covariates": {
6  tests/fixtures/text/settings.yml (vendored)
@@ -10,8 +10,7 @@ models:
     tokens_per_minute: ${GRAPHRAG_LLM_TPM}
    requests_per_minute: ${GRAPHRAG_LLM_RPM}
     model_supports_json: true
-    parallelization_num_threads: 50
-    parallelization_stagger: 0.3
+    concurrent_requests: 50
     async_mode: threaded
   default_embedding_model:
     azure_auth_type: api_key
@@ -23,8 +22,7 @@ models:
     model: ${GRAPHRAG_EMBEDDING_MODEL}
     tokens_per_minute: ${GRAPHRAG_EMBEDDING_TPM}
     requests_per_minute: ${GRAPHRAG_EMBEDDING_RPM}
-    parallelization_num_threads: 50
-    parallelization_stagger: 0.3
+    concurrent_requests: 50
     async_mode: threaded
 
 vector_store:
@@ -265,13 +265,7 @@ def assert_language_model_configs(
     assert actual.requests_per_minute == expected.requests_per_minute
     assert actual.max_retries == expected.max_retries
     assert actual.max_retry_wait == expected.max_retry_wait
-    assert (
-        actual.sleep_on_rate_limit_recommendation
-        == expected.sleep_on_rate_limit_recommendation
-    )
     assert actual.concurrent_requests == expected.concurrent_requests
-    assert actual.parallelization_stagger == expected.parallelization_stagger
-    assert actual.parallelization_num_threads == expected.parallelization_num_threads
     assert actual.async_mode == expected.async_mode
     if actual.responses is not None:
         assert expected.responses is not None
@@ -1,6 +1,6 @@
 # Copyright (c) 2024 Microsoft Corporation.
 # Licensed under the MIT License
-from fnllm import ChatLLM
+from fnllm.types import ChatLLM
 from pydantic import BaseModel
 
 from graphrag.index.llm.mock_llm import MockChatLLM