Tokenizer (#2051)
* Add LiteLLM chat and embedding model providers.
* Fix code review findings.
* Add litellm.
* Fix formatting.
* Update dictionary.
* Update litellm.
* Fix embedding.
* Remove manual use of tiktoken and replace with Tokenizer interface. Adds support for encoding and decoding the models supported by litellm.
* Update litellm.
* Configure litellm to drop unsupported params.
* Cleanup semversioner release notes.
* Add num_tokens util to Tokenizer interface.
* Update litellm service factories.
* Cleanup litellm chat/embedding model argument assignment.
* Update chat and embedding type field for litellm use and future migration away from fnllm.
* Flatten litellm service organization.
* Update litellm.
* Update litellm factory validation.
* Flatten litellm rate limit service organization.
* Update rate limiter - disable with None/null instead of 0.
* Fix usage of get_tokenizer.
* Update litellm service registrations.
* Add jitter to exponential retry.
* Update validation.
* Update validation.
* Add litellm request logging layer.
* Update cache key.
* Update defaults.

Co-authored-by: Alonso Guevara <alonsog@microsoft.com>
This commit is contained in:
parent 82cd3b7df2
commit 2b70e4a4f3
@ -0,0 +1,4 @@
{
    "type": "minor",
    "description": "Add LiteLLM chat and embedding model providers."
}
@ -81,6 +81,7 @@ typer
spacy
kwargs
ollama
litellm

# Library Methods
iterrows
@ -103,6 +104,8 @@ isin
nocache
nbconvert
levelno
acompletion
aembedding

# HTML
nbsp
@ -47,6 +47,7 @@ from graphrag.prompt_tune.generator.language import detect_language
from graphrag.prompt_tune.generator.persona import generate_persona
from graphrag.prompt_tune.loader.input import load_docs_in_chunks
from graphrag.prompt_tune.types import DocSelectionType
from graphrag.tokenizer.get_tokenizer import get_tokenizer

logger = logging.getLogger(__name__)

@ -166,7 +167,7 @@ async def generate_indexing_prompts(
        examples=examples,
        language=language,
        json_mode=False,  # config.llm.model_supports_json should be used, but these prompts are used in non-json mode by the index engine
        encoding_model=extract_graph_llm_settings.encoding_model,
        tokenizer=get_tokenizer(model_config=extract_graph_llm_settings),
        max_token_count=max_tokens,
        min_examples_required=min_examples_required,
    )
@ -3,6 +3,7 @@

"""Common default configuration values."""

from collections.abc import Callable
from dataclasses import dataclass, field
from pathlib import Path
from typing import ClassVar, Literal
@ -23,6 +24,25 @@ from graphrag.config.enums import (
from graphrag.index.operations.build_noun_graph.np_extractors.stop_words import (
    EN_STOP_WORDS,
)
from graphrag.language_model.providers.litellm.services.rate_limiter.rate_limiter import (
    RateLimiter,
)
from graphrag.language_model.providers.litellm.services.rate_limiter.static_rate_limiter import (
    StaticRateLimiter,
)
from graphrag.language_model.providers.litellm.services.retry.exponential_retry import (
    ExponentialRetry,
)
from graphrag.language_model.providers.litellm.services.retry.incremental_wait_retry import (
    IncrementalWaitRetry,
)
from graphrag.language_model.providers.litellm.services.retry.native_wait_retry import (
    NativeRetry,
)
from graphrag.language_model.providers.litellm.services.retry.random_wait_retry import (
    RandomWaitRetry,
)
from graphrag.language_model.providers.litellm.services.retry.retry import Retry

DEFAULT_OUTPUT_BASE_DIR = "output"
DEFAULT_CHAT_MODEL_ID = "default_chat_model"
@ -39,6 +59,18 @@ ENCODING_MODEL = "cl100k_base"
COGNITIVE_SERVICES_AUDIENCE = "https://cognitiveservices.azure.com/.default"


DEFAULT_RETRY_SERVICES: dict[str, Callable[..., Retry]] = {
    "native": NativeRetry,
    "exponential_backoff": ExponentialRetry,
    "random_wait": RandomWaitRetry,
    "incremental_wait": IncrementalWaitRetry,
}

DEFAULT_RATE_LIMITER_SERVICES: dict[str, Callable[..., RateLimiter]] = {
    "static": StaticRateLimiter,
}


@dataclass
class BasicSearchDefaults:
    """Default values for basic search."""
@ -275,6 +307,7 @@ class LanguageModelDefaults:

    api_key: None = None
    auth_type: ClassVar[AuthType] = AuthType.APIKey
    model_provider: str | None = None
    encoding_model: str = ""
    max_tokens: int | None = None
    temperature: float = 0
@ -294,6 +327,7 @@ class LanguageModelDefaults:
    model_supports_json: None = None
    tokens_per_minute: Literal["auto"] = "auto"
    requests_per_minute: Literal["auto"] = "auto"
    rate_limit_strategy: str | None = "static"
    retry_strategy: str = "native"
    max_retries: int = 10
    max_retry_wait: float = 10.0
@ -86,10 +86,12 @@ class ModelType(str, Enum):
    # Embeddings
    OpenAIEmbedding = "openai_embedding"
    AzureOpenAIEmbedding = "azure_openai_embedding"
    Embedding = "embedding"

    # Chat Completion
    OpenAIChat = "openai_chat"
    AzureOpenAIChat = "azure_openai_chat"
    Chat = "chat"

    # Debug
    MockChat = "mock_chat"
@ -37,6 +37,12 @@ from graphrag.config.models.summarize_descriptions_config import (
from graphrag.config.models.text_embedding_config import TextEmbeddingConfig
from graphrag.config.models.umap_config import UmapConfig
from graphrag.config.models.vector_store_config import VectorStoreConfig
from graphrag.language_model.providers.litellm.services.rate_limiter.rate_limiter_factory import (
    RateLimiterFactory,
)
from graphrag.language_model.providers.litellm.services.retry.retry_factory import (
    RetryFactory,
)


class GraphRagConfig(BaseModel):
@ -89,6 +95,47 @@ class GraphRagConfig(BaseModel):
        if defs.DEFAULT_EMBEDDING_MODEL_ID not in self.models:
            raise LanguageModelConfigMissingError(defs.DEFAULT_EMBEDDING_MODEL_ID)

    def _validate_retry_services(self) -> None:
        """Validate the retry services configuration."""
        retry_factory = RetryFactory()

        for model_id, model in self.models.items():
            if model.retry_strategy != "none":
                if model.retry_strategy not in retry_factory:
                    msg = f"Retry strategy '{model.retry_strategy}' for model '{model_id}' is not registered. Available strategies: {', '.join(retry_factory.keys())}"
                    raise ValueError(msg)

                _ = retry_factory.create(
                    strategy=model.retry_strategy,
                    max_attempts=model.max_retries,
                    max_retry_wait=model.max_retry_wait,
                )

    def _validate_rate_limiter_services(self) -> None:
        """Validate the rate limiter services configuration."""
        rate_limiter_factory = RateLimiterFactory()

        for model_id, model in self.models.items():
            if model.rate_limit_strategy is not None:
                if model.rate_limit_strategy not in rate_limiter_factory:
                    msg = f"Rate Limiter strategy '{model.rate_limit_strategy}' for model '{model_id}' is not registered. Available strategies: {', '.join(rate_limiter_factory.keys())}"
                    raise ValueError(msg)

                rpm = (
                    model.requests_per_minute
                    if type(model.requests_per_minute) is int
                    else None
                )
                tpm = (
                    model.tokens_per_minute
                    if type(model.tokens_per_minute) is int
                    else None
                )
                if rpm is not None or tpm is not None:
                    _ = rate_limiter_factory.create(
                        strategy=model.rate_limit_strategy, rpm=rpm, tpm=tpm
                    )

    input: InputConfig = Field(
        description="The input configuration.", default=InputConfig()
    )
@ -300,6 +347,11 @@ class GraphRagConfig(BaseModel):
            raise ValueError(msg)
        store.db_uri = str((Path(self.root_dir) / store.db_uri).resolve())

    def _validate_factories(self) -> None:
        """Validate the factories used in the configuration."""
        self._validate_retry_services()
        self._validate_rate_limiter_services()

    def get_language_model_config(self, model_id: str) -> LanguageModelConfig:
        """Get a model configuration by ID.

@ -360,4 +412,5 @@ class GraphRagConfig(BaseModel):
        self._validate_multi_output_base_dirs()
        self._validate_update_index_output_base_dir()
        self._validate_vector_store_db_uri()
        self._validate_factories()
        return self
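To make concrete what these validators exercise, here is a hedged sketch of the same factory calls made directly. It assumes the default strategies from config/defaults.py have already been registered with their factories (the "service registrations" step mentioned in the commit message); the keyword arguments simply mirror the create() calls in the validation code above.

```python
from graphrag.language_model.providers.litellm.services.rate_limiter.rate_limiter_factory import (
    RateLimiterFactory,
)
from graphrag.language_model.providers.litellm.services.retry.retry_factory import (
    RetryFactory,
)

retry_factory = RetryFactory()
if "exponential_backoff" in retry_factory:  # membership check, as in _validate_retry_services
    retry = retry_factory.create(
        strategy="exponential_backoff", max_attempts=10, max_retry_wait=10.0
    )

rate_limiter_factory = RateLimiterFactory()
# Mirrors _validate_rate_limiter_services: rpm/tpm must already be ints (or None) here.
limiter = rate_limiter_factory.create(strategy="static", rpm=60, tpm=50_000)
```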
@ -73,8 +73,11 @@ class LanguageModelConfig(BaseModel):
        ConflictingSettingsError
            If the Azure authentication type conflicts with the model being used.
        """
        if self.auth_type == AuthType.AzureManagedIdentity and (
            self.type == ModelType.OpenAIChat or self.type == ModelType.OpenAIEmbedding
        if (
            self.auth_type == AuthType.AzureManagedIdentity
            and self.type != ModelType.AzureOpenAIChat
            and self.type != ModelType.AzureOpenAIEmbedding
            and self.model_provider != "azure"  # indicates Litellm + AOI
        ):
            msg = f"auth_type of azure_managed_identity is not supported for model type {self.type}. Please rerun `graphrag init` and set the auth_type to api_key."
            raise ConflictingSettingsError(msg)
@ -94,6 +97,27 @@ class LanguageModelConfig(BaseModel):
            msg = f"Model type {self.type} is not recognized, must be one of {ModelFactory.get_chat_models() + ModelFactory.get_embedding_models()}."
            raise KeyError(msg)

    model_provider: str | None = Field(
        description="The model provider to use.",
        default=language_model_defaults.model_provider,
    )

    def _validate_model_provider(self) -> None:
        """Validate the model provider.

        Required when using Litellm.

        Raises
        ------
        KeyError
            If the model provider is not recognized.
        """
        if (self.type == ModelType.Chat or self.type == ModelType.Embedding) and (
            self.model_provider is None or self.model_provider.strip() == ""
        ):
            msg = f"Model provider must be specified when using type == {self.type}."
            raise KeyError(msg)

    model: str = Field(description="The LLM model to use.")
    encoding_model: str = Field(
        description="The encoding model to use",
@ -103,12 +127,27 @@ class LanguageModelConfig(BaseModel):
    def _validate_encoding_model(self) -> None:
        """Validate the encoding model.

        The default behavior is to use an encoding model that matches the LLM model.
        LiteLLM supports 100+ models and their tokenization, so there is no need to
        set the encoding model when using the new LiteLLM provider, as was required with the fnllm provider.

        Users can still manually specify a tiktoken-based encoding model even with the LiteLLM provider,
        in which case the specified encoding model is used regardless of the LLM model being used, even if
        it is not an OpenAI-based model.

        If not using the LiteLLM provider, set the encoding model based on the LLM model name.
        This is for backward compatibility with the existing fnllm provider until fnllm is removed.

        Raises
        ------
        KeyError
            If the model name is not recognized.
        """
        if self.encoding_model.strip() == "":
        if (
            self.type != ModelType.Chat
            and self.type != ModelType.Embedding
            and self.encoding_model.strip() == ""
        ):
            self.encoding_model = tiktoken.encoding_name_for_model(self.model)

    api_base: str | None = Field(
@ -129,6 +168,7 @@ class LanguageModelConfig(BaseModel):
        if (
            self.type == ModelType.AzureOpenAIChat
            or self.type == ModelType.AzureOpenAIEmbedding
            or self.model_provider == "azure"  # indicates Litellm + AOI
        ) and (self.api_base is None or self.api_base.strip() == ""):
            raise AzureApiBaseMissingError(self.type)

@ -150,6 +190,7 @@ class LanguageModelConfig(BaseModel):
        if (
            self.type == ModelType.AzureOpenAIChat
            or self.type == ModelType.AzureOpenAIEmbedding
            or self.model_provider == "azure"  # indicates Litellm + AOI
        ) and (self.api_version is None or self.api_version.strip() == ""):
            raise AzureApiVersionMissingError(self.type)

@ -171,6 +212,7 @@ class LanguageModelConfig(BaseModel):
        if (
            self.type == ModelType.AzureOpenAIChat
            or self.type == ModelType.AzureOpenAIEmbedding
            or self.model_provider == "azure"  # indicates Litellm + AOI
        ) and (self.deployment_name is None or self.deployment_name.strip() == ""):
            raise AzureDeploymentNameMissingError(self.type)

@ -212,6 +254,14 @@ class LanguageModelConfig(BaseModel):
            msg = f"Tokens per minute must be a non zero positive number, 'auto' or null. Suggested value: {language_model_defaults.tokens_per_minute}."
            raise ValueError(msg)

        if (
            (self.type == ModelType.Chat or self.type == ModelType.Embedding)
            and self.rate_limit_strategy is not None
            and self.tokens_per_minute == "auto"
        ):
            msg = f"tokens_per_minute cannot be set to 'auto' when using type '{self.type}'. Please set it to a positive integer or null to disable."
            raise ValueError(msg)

    requests_per_minute: int | Literal["auto"] | None = Field(
        description="The number of requests per minute to use for the LLM service.",
        default=language_model_defaults.requests_per_minute,
@ -230,6 +280,19 @@ class LanguageModelConfig(BaseModel):
            msg = f"Requests per minute must be a non zero positive number, 'auto' or null. Suggested value: {language_model_defaults.requests_per_minute}."
            raise ValueError(msg)

        if (
            (self.type == ModelType.Chat or self.type == ModelType.Embedding)
            and self.rate_limit_strategy is not None
            and self.requests_per_minute == "auto"
        ):
            msg = f"requests_per_minute cannot be set to 'auto' when using type '{self.type}'. Please set it to a positive integer or null to disable."
            raise ValueError(msg)

    rate_limit_strategy: str | None = Field(
        description="The rate limit strategy to use for the LLM service.",
        default=language_model_defaults.rate_limit_strategy,
    )

    retry_strategy: str = Field(
        description="The retry strategy to use for the LLM service.",
        default=language_model_defaults.retry_strategy,
@ -318,6 +381,7 @@ class LanguageModelConfig(BaseModel):
    @model_validator(mode="after")
    def _validate_model(self):
        self._validate_type()
        self._validate_model_provider()
        self._validate_auth_type()
        self._validate_api_key()
        self._validate_tokens_per_minute()
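A small sketch of the fallback described in the encoding-model docstring above: with the legacy fnllm types an empty `encoding_model` is inferred from the model name via tiktoken, while the new litellm `chat`/`embedding` types leave it blank and let LiteLLM resolve the tokenizer per model (unless the user sets `encoding_model` explicitly).

```python
import tiktoken

# Legacy fnllm path: an empty encoding_model is filled in from the model name.
print(tiktoken.encoding_name_for_model("gpt-4"))  # -> "cl100k_base"

# LiteLLM path (type == "chat" / "embedding"): encoding_model stays "" and the
# tokenizer is chosen per model by LiteLLM; setting encoding_model manually
# (e.g. encoding_model="cl100k_base") forces that tiktoken encoding instead,
# even for non-OpenAI models.
```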
4  graphrag/factory/__init__.py  Normal file
@ -0,0 +1,4 @@
# Copyright (c) 2025 Microsoft Corporation.
# Licensed under the MIT License

"""Factory module."""
68  graphrag/factory/factory.py  Normal file
@ -0,0 +1,68 @@
# Copyright (c) 2025 Microsoft Corporation.
# Licensed under the MIT License

"""Factory ABC."""

from abc import ABC
from collections.abc import Callable
from typing import Any, ClassVar, Generic, TypeVar

T = TypeVar("T", covariant=True)


class Factory(ABC, Generic[T]):
    """Abstract base class for factories."""

    _instance: ClassVar["Factory | None"] = None

    def __new__(cls, *args: Any, **kwargs: Any) -> "Factory":
        """Create a new instance of Factory if it does not exist."""
        if cls._instance is None:
            cls._instance = super().__new__(cls, *args, **kwargs)
        return cls._instance

    def __init__(self):
        if not hasattr(self, "_initialized"):
            self._services: dict[str, Callable[..., T]] = {}
            self._initialized = True

    def __contains__(self, strategy: str) -> bool:
        """Check if a strategy is registered."""
        return strategy in self._services

    def keys(self) -> list[str]:
        """Get a list of registered strategy names."""
        return list(self._services.keys())

    def register(self, *, strategy: str, service_initializer: Callable[..., T]) -> None:
        """
        Register a new service.

        Args
        ----
            strategy: The name of the strategy.
            service_initializer: A callable that creates an instance of T.
        """
        self._services[strategy] = service_initializer

    def create(self, *, strategy: str, **kwargs: Any) -> T:
        """
        Create a service instance based on the strategy.

        Args
        ----
            strategy: The name of the strategy.
            **kwargs: Additional arguments to pass to the service initializer.

        Returns
        -------
            An instance of T.

        Raises
        ------
            ValueError: If the strategy is not registered.
        """
        if strategy not in self._services:
            msg = f"Strategy '{strategy}' is not registered."
            raise ValueError(msg)
        return self._services[strategy](**kwargs)
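For orientation, a hedged usage sketch of this factory base class. The `Greeter`/`GreeterFactory` names are invented for illustration only; the real subclasses introduced by this PR are the retry and rate-limiter factories.

```python
from graphrag.factory.factory import Factory


class Greeter:
    """Hypothetical service type used only for this sketch."""

    def __init__(self, greeting: str = "hello"):
        self.greeting = greeting

    def greet(self, name: str) -> str:
        return f"{self.greeting}, {name}"


class GreeterFactory(Factory[Greeter]):
    """Concrete factory; because of the singleton __new__, every GreeterFactory() shares one registry."""


factory = GreeterFactory()
factory.register(strategy="casual", service_initializer=lambda **kw: Greeter(**kw))

assert "casual" in factory            # __contains__ checks the registry
assert factory.keys() == ["casual"]   # keys() lists registered strategy names

service = factory.create(strategy="casual", greeting="hi")
print(service.greet("graphrag"))      # -> "hi, graphrag"
```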
@ -11,7 +11,7 @@ import tiktoken
from graphrag.config.models.chunking_config import ChunkingConfig
from graphrag.index.operations.chunk_text.typing import TextChunk
from graphrag.index.text_splitting.text_splitting import (
    Tokenizer,
    TokenChunkerOptions,
    split_multiple_texts_on_tokens,
)
from graphrag.logger.progress import ProgressTicker
@ -45,7 +45,7 @@ def run_tokens(
    encode, decode = get_encoding_fn(encoding_name)
    return split_multiple_texts_on_tokens(
        input,
        Tokenizer(
        TokenChunkerOptions(
            chunk_overlap=chunk_overlap,
            tokens_per_chunk=tokens_per_chunk,
            encode=encode,

@ -18,6 +18,7 @@ from graphrag.index.utils.is_null import is_null
from graphrag.language_model.manager import ModelManager
from graphrag.language_model.protocol.base import EmbeddingModel
from graphrag.logger.progress import ProgressTicker, progress_ticker
from graphrag.tokenizer.get_tokenizer import get_tokenizer

logger = logging.getLogger(__name__)

@ -79,7 +80,7 @@ def _get_splitter(
    config: LanguageModelConfig, batch_max_tokens: int
) -> TokenTextSplitter:
    return TokenTextSplitter(
        encoding_name=config.encoding_model,
        tokenizer=get_tokenizer(model_config=config),
        chunk_size=batch_max_tokens,
    )


@ -5,16 +5,16 @@

import logging
from abc import ABC, abstractmethod
from collections.abc import Callable, Collection, Iterable
from collections.abc import Callable, Iterable
from dataclasses import dataclass
from typing import Any, Literal, cast
from typing import Any, cast

import pandas as pd
import tiktoken

import graphrag.config.defaults as defs
from graphrag.index.operations.chunk_text.typing import TextChunk
from graphrag.logger.progress import ProgressTicker
from graphrag.tokenizer.get_tokenizer import get_tokenizer
from graphrag.tokenizer.tokenizer import Tokenizer

EncodedText = list[int]
DecodeFn = Callable[[EncodedText], str]
@ -25,8 +25,8 @@ logger = logging.getLogger(__name__)


@dataclass(frozen=True)
class Tokenizer:
    """Tokenizer data class."""
class TokenChunkerOptions:
    """TokenChunkerOptions data class."""

    chunk_overlap: int
    """Overlap in tokens between chunks"""
@ -83,44 +83,18 @@ class NoopTextSplitter(TextSplitter):
class TokenTextSplitter(TextSplitter):
    """Token text splitter class definition."""

    _allowed_special: Literal["all"] | set[str]
    _disallowed_special: Literal["all"] | Collection[str]

    def __init__(
        self,
        encoding_name: str = defs.ENCODING_MODEL,
        model_name: str | None = None,
        allowed_special: Literal["all"] | set[str] | None = None,
        disallowed_special: Literal["all"] | Collection[str] = "all",
        tokenizer: Tokenizer | None = None,
        **kwargs: Any,
    ):
        """Init method definition."""
        super().__init__(**kwargs)
        if model_name is not None:
            try:
                enc = tiktoken.encoding_for_model(model_name)
            except KeyError:
                logger.exception(
                    "Model %s not found, using %s", model_name, encoding_name
                )
                enc = tiktoken.get_encoding(encoding_name)
        else:
            enc = tiktoken.get_encoding(encoding_name)
        self._tokenizer = enc
        self._allowed_special = allowed_special or set()
        self._disallowed_special = disallowed_special

    def encode(self, text: str) -> list[int]:
        """Encode the given text into an int-vector."""
        return self._tokenizer.encode(
            text,
            allowed_special=self._allowed_special,
            disallowed_special=self._disallowed_special,
        )
        self._tokenizer = tokenizer or get_tokenizer()

    def num_tokens(self, text: str) -> int:
        """Return the number of tokens in a string."""
        return len(self.encode(text))
        return len(self._tokenizer.encode(text))

    def split_text(self, text: str | list[str]) -> list[str]:
        """Split text method."""
@ -132,17 +106,17 @@ class TokenTextSplitter(TextSplitter):
            msg = f"Attempting to split a non-string value, actual is {type(text)}"
            raise TypeError(msg)

        tokenizer = Tokenizer(
        token_chunker_options = TokenChunkerOptions(
            chunk_overlap=self._chunk_overlap,
            tokens_per_chunk=self._chunk_size,
            decode=self._tokenizer.decode,
            encode=lambda text: self.encode(text),
            encode=self._tokenizer.encode,
        )

        return split_single_text_on_tokens(text=text, tokenizer=tokenizer)
        return split_single_text_on_tokens(text=text, tokenizer=token_chunker_options)


def split_single_text_on_tokens(text: str, tokenizer: Tokenizer) -> list[str]:
def split_single_text_on_tokens(text: str, tokenizer: TokenChunkerOptions) -> list[str]:
    """Split a single text and return chunks using the tokenizer."""
    result = []
    input_ids = tokenizer.encode(text)
@ -166,7 +140,7 @@ def split_single_text_on_tokens(text: str, tokenizer: Tokenizer) -> list[str]:
# Adapted from - https://github.com/langchain-ai/langchain/blob/77b359edf5df0d37ef0d539f678cf64f5557cb54/libs/langchain/langchain/text_splitter.py#L471
# So we could have better control over the chunking process
def split_multiple_texts_on_tokens(
    texts: list[str], tokenizer: Tokenizer, tick: ProgressTicker
    texts: list[str], tokenizer: TokenChunkerOptions, tick: ProgressTicker
) -> list[TextChunk]:
    """Split multiple texts and return chunks with metadata using the tokenizer."""
    result = []
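A minimal sketch (my own illustration, not part of the diff) of how the renamed TokenChunkerOptions is consumed by split_single_text_on_tokens. It assumes a tiktoken encoding supplies the encode/decode callables; the field names are taken from the constructor calls shown in this diff.

```python
import tiktoken

from graphrag.index.text_splitting.text_splitting import (
    TokenChunkerOptions,
    split_single_text_on_tokens,
)

enc = tiktoken.get_encoding("cl100k_base")
options = TokenChunkerOptions(
    chunk_overlap=32,
    tokens_per_chunk=256,
    encode=enc.encode,  # str -> list[int]
    decode=enc.decode,  # list[int] -> str
)

chunks = split_single_text_on_tokens(text="some long document ...", tokenizer=options)
print(len(chunks))
```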
@ -14,6 +14,10 @@ from graphrag.language_model.providers.fnllm.models import (
    OpenAIChatFNLLM,
    OpenAIEmbeddingFNLLM,
)
from graphrag.language_model.providers.litellm.chat_model import LitellmChatModel
from graphrag.language_model.providers.litellm.embedding_model import (
    LitellmEmbeddingModel,
)


class ModelFactory:
@ -105,6 +109,7 @@ ModelFactory.register_chat(
ModelFactory.register_chat(
    ModelType.OpenAIChat.value, lambda **kwargs: OpenAIChatFNLLM(**kwargs)
)
ModelFactory.register_chat(ModelType.Chat, lambda **kwargs: LitellmChatModel(**kwargs))

ModelFactory.register_embedding(
    ModelType.AzureOpenAIEmbedding.value,
@ -113,3 +118,6 @@ ModelFactory.register_embedding(
ModelFactory.register_embedding(
    ModelType.OpenAIEmbedding.value, lambda **kwargs: OpenAIEmbeddingFNLLM(**kwargs)
)
ModelFactory.register_embedding(
    ModelType.Embedding, lambda **kwargs: LitellmEmbeddingModel(**kwargs)
)
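The same registration hooks can, in principle, be used to plug in a custom provider. A hedged sketch: `CustomChatModel` and the `"my_custom_chat"` key are hypothetical, and the ModelFactory import path is assumed from context rather than shown in this diff.

```python
from graphrag.language_model.factory import ModelFactory  # import path assumed


class CustomChatModel:
    """Hypothetical chat model implementation following the same constructor shape as LitellmChatModel."""

    def __init__(self, *, name: str, config, cache=None, **kwargs):
        self.name = name
        self.config = config

    def chat(self, prompt: str, history: list | None = None, **kwargs):
        ...  # call your backend here


# Register under a custom key, mirroring the LiteLLM registrations above.
ModelFactory.register_chat("my_custom_chat", lambda **kwargs: CustomChatModel(**kwargs))
```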
4  graphrag/language_model/providers/litellm/__init__.py  Normal file
@ -0,0 +1,4 @@
# Copyright (c) 2025 Microsoft Corporation.
# Licensed under the MIT License

"""GraphRAG LiteLLM module. Provides LiteLLM-based implementations of chat and embedding models."""
413  graphrag/language_model/providers/litellm/chat_model.py  Normal file
@ -0,0 +1,413 @@
# Copyright (c) 2025 Microsoft Corporation.
# Licensed under the MIT License

"""Chat model implementation using Litellm."""

import inspect
import json
from collections.abc import AsyncGenerator, Generator
from typing import TYPE_CHECKING, Any, cast

import litellm
from azure.identity import DefaultAzureCredential, get_bearer_token_provider
from litellm import (
    CustomStreamWrapper,
    ModelResponse,  # type: ignore
    acompletion,
    completion,
)
from pydantic import BaseModel, Field

from graphrag.config.defaults import COGNITIVE_SERVICES_AUDIENCE
from graphrag.config.enums import AuthType
from graphrag.language_model.providers.litellm.request_wrappers.with_cache import (
    with_cache,
)
from graphrag.language_model.providers.litellm.request_wrappers.with_logging import (
    with_logging,
)
from graphrag.language_model.providers.litellm.request_wrappers.with_rate_limiter import (
    with_rate_limiter,
)
from graphrag.language_model.providers.litellm.request_wrappers.with_retries import (
    with_retries,
)
from graphrag.language_model.providers.litellm.types import (
    AFixedModelCompletion,
    FixedModelCompletion,
)

if TYPE_CHECKING:
    from graphrag.cache.pipeline_cache import PipelineCache
    from graphrag.config.models.language_model_config import LanguageModelConfig
    from graphrag.language_model.response.base import ModelResponse as MR  # noqa: N817

litellm.suppress_debug_info = True


def _create_base_completions(
    model_config: "LanguageModelConfig",
) -> tuple[FixedModelCompletion, AFixedModelCompletion]:
    """Wrap the base litellm completion function with the model configuration.

    Args
    ----
        model_config: The configuration for the language model.

    Returns
    -------
        A tuple containing the synchronous and asynchronous completion functions.
    """
    model_provider = model_config.model_provider
    model = model_config.deployment_name or model_config.model

    base_args: dict[str, Any] = {
        "drop_params": True,  # LiteLLM drop unsupported params for selected model.
        "model": f"{model_provider}/{model}",
        "timeout": model_config.request_timeout,
        "top_p": model_config.top_p,
        "n": model_config.n,
        "temperature": model_config.temperature,
        "frequency_penalty": model_config.frequency_penalty,
        "presence_penalty": model_config.presence_penalty,
        "api_base": model_config.api_base,
        "api_version": model_config.api_version,
        "api_key": model_config.api_key,
        "organization": model_config.organization,
        "proxy": model_config.proxy,
        "audience": model_config.audience,
        "max_tokens": model_config.max_tokens,
        "max_completion_tokens": model_config.max_completion_tokens,
        "reasoning_effort": model_config.reasoning_effort,
    }

    if model_config.auth_type == AuthType.AzureManagedIdentity:
        if model_config.model_provider != "azure":
            msg = "Azure Managed Identity authentication is only supported for Azure models."
            raise ValueError(msg)

        base_args["azure_ad_token_provider"] = get_bearer_token_provider(
            DefaultAzureCredential(),
            COGNITIVE_SERVICES_AUDIENCE,
        )

    def _base_completion(**kwargs: Any) -> ModelResponse | CustomStreamWrapper:
        new_args = {**base_args, **kwargs}

        if "name" in new_args:
            new_args.pop("name")

        return completion(**new_args)

    async def _base_acompletion(**kwargs: Any) -> ModelResponse | CustomStreamWrapper:
        new_args = {**base_args, **kwargs}

        if "name" in new_args:
            new_args.pop("name")

        return await acompletion(**new_args)

    return (_base_completion, _base_acompletion)


def _create_completions(
    model_config: "LanguageModelConfig",
    cache: "PipelineCache | None",
    cache_key_prefix: str,
) -> tuple[FixedModelCompletion, AFixedModelCompletion]:
    """Wrap the base litellm completion function with the model configuration and additional features.

    Wrap the base litellm completion function with instance variables based on the model configuration.
    Then wrap additional features such as rate limiting, retries, and caching, if enabled.

    Final function composition order:
    - Logging(Cache(Retries(RateLimiter(ModelCompletion()))))

    Args
    ----
        model_config: The configuration for the language model.
        cache: Optional cache for storing responses.
        cache_key_prefix: Prefix for cache keys.

    Returns
    -------
        A tuple containing the synchronous and asynchronous completion functions.

    """
    completion, acompletion = _create_base_completions(model_config)

    # TODO: For v2.x release, rpm/tpm can be int or str (auto) for backwards compatibility with fnllm.
    # LiteLLM does not support "auto", so we have to check those values here.
    # For v3 release, force rpm/tpm to be int and remove the type checks below
    # and just check if rate_limit_strategy is enabled.
    if model_config.rate_limit_strategy is not None:
        rpm = (
            model_config.requests_per_minute
            if type(model_config.requests_per_minute) is int
            else None
        )
        tpm = (
            model_config.tokens_per_minute
            if type(model_config.tokens_per_minute) is int
            else None
        )
        if rpm is not None or tpm is not None:
            completion, acompletion = with_rate_limiter(
                sync_fn=completion,
                async_fn=acompletion,
                model_config=model_config,
                rpm=rpm,
                tpm=tpm,
            )

    if model_config.retry_strategy != "none":
        completion, acompletion = with_retries(
            sync_fn=completion,
            async_fn=acompletion,
            model_config=model_config,
        )

    if cache is not None:
        completion, acompletion = with_cache(
            sync_fn=completion,
            async_fn=acompletion,
            model_config=model_config,
            cache=cache,
            request_type="chat",
            cache_key_prefix=cache_key_prefix,
        )

    completion, acompletion = with_logging(
        sync_fn=completion,
        async_fn=acompletion,
    )

    return (completion, acompletion)


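To make the composition order in the docstring above concrete, here is a stripped-down sketch of the same wrapping pattern. The `with_layer` helper is a generic placeholder, not one of the actual `with_*` implementations in this PR: each wrapper takes the current function and returns a new one, so the last wrapper applied becomes the outermost layer at call time.

```python
from collections.abc import Callable
from typing import Any

RequestFn = Callable[..., Any]


def with_layer(label: str, fn: RequestFn) -> RequestFn:
    """Placeholder for a with_rate_limiter / with_retries / with_cache / with_logging style wrapper."""

    def wrapped(**kwargs: Any) -> Any:
        print(f"enter {label}")
        try:
            return fn(**kwargs)
        finally:
            print(f"exit {label}")

    return wrapped


def base_completion(**kwargs: Any) -> str:
    return "model response"


fn: RequestFn = base_completion
for label in ("rate_limiter", "retries", "cache", "logging"):  # innermost applied first
    fn = with_layer(label, fn)

fn(messages=[{"role": "user", "content": "hi"}])
# Prints enter logging, enter cache, enter retries, enter rate_limiter, then exits in reverse:
# Logging(Cache(Retries(RateLimiter(base_completion())))).
```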
class LitellmModelOutput(BaseModel):
    """A model representing the output from a language model."""

    content: str = Field(description="The generated text content")
    full_response: None = Field(
        default=None, description="The full response from the model, if available"
    )


class LitellmModelResponse(BaseModel):
    """A model representing the response from a language model."""

    output: LitellmModelOutput = Field(description="The output from the model")
    parsed_response: BaseModel | None = Field(
        default=None, description="Parsed response from the model"
    )
    history: list = Field(
        default_factory=list,
        description="Conversation history including the prompt and response",
    )


class LitellmChatModel:
    """LiteLLM-based Chat Model."""

    def __init__(
        self,
        name: str,
        config: "LanguageModelConfig",
        cache: "PipelineCache | None" = None,
        **kwargs: Any,
    ):
        self.name = name
        self.config = config
        self.cache = cache.child(self.name) if cache else None
        self.completion, self.acompletion = _create_completions(
            config, self.cache, "chat"
        )

    def _get_kwargs(self, **kwargs: Any) -> dict[str, Any]:
        """Get model arguments supported by litellm."""
        args_to_include = [
            "name",
            "modalities",
            "prediction",
            "audio",
            "logit_bias",
            "metadata",
            "user",
            "response_format",
            "seed",
            "tools",
            "tool_choice",
            "logprobs",
            "top_logprobs",
            "parallel_tool_calls",
            "web_search_options",
            "extra_headers",
            "functions",
            "function_call",
            "thinking",
        ]
        new_args = {k: v for k, v in kwargs.items() if k in args_to_include}

        # If using JSON, check if response_format should be a Pydantic model or just a general JSON object
        if kwargs.get("json"):
            new_args["response_format"] = {"type": "json_object"}

            if (
                "json_model" in kwargs
                and inspect.isclass(kwargs["json_model"])
                and issubclass(kwargs["json_model"], BaseModel)
            ):
                new_args["response_format"] = kwargs["json_model"]

        return new_args

    async def achat(
        self, prompt: str, history: list | None = None, **kwargs: Any
    ) -> "MR":
        """
        Generate a response for the given prompt and history.

        Args
        ----
            prompt: The prompt to generate a response for.
            history: Optional chat history.
            **kwargs: Additional keyword arguments (e.g., model parameters).

        Returns
        -------
            LitellmModelResponse: The generated model response.
        """
        new_kwargs = self._get_kwargs(**kwargs)
        messages: list[dict[str, str]] = history or []
        messages.append({"role": "user", "content": prompt})

        response = await self.acompletion(messages=messages, stream=False, **new_kwargs)  # type: ignore

        messages.append({
            "role": "assistant",
            "content": response.choices[0].message.content or "",  # type: ignore
        })

        parsed_response: BaseModel | None = None
        if "response_format" in new_kwargs:
            parsed_dict: dict[str, Any] = json.loads(
                response.choices[0].message.content or "{}"  # type: ignore
            )
            parsed_response = parsed_dict  # type: ignore
            if inspect.isclass(new_kwargs["response_format"]) and issubclass(
                new_kwargs["response_format"], BaseModel
            ):
                # If response_format is a pydantic model, instantiate it
                model_initializer = cast(
                    "type[BaseModel]", new_kwargs["response_format"]
                )
                parsed_response = model_initializer(**parsed_dict)

        return LitellmModelResponse(
            output=LitellmModelOutput(
                content=response.choices[0].message.content or ""  # type: ignore
            ),
            parsed_response=parsed_response,
            history=messages,
        )

    async def achat_stream(
        self, prompt: str, history: list | None = None, **kwargs: Any
    ) -> AsyncGenerator[str, None]:
        """
        Generate a response for the given prompt and history.

        Args
        ----
            prompt: The prompt to generate a response for.
            history: Optional chat history.
            **kwargs: Additional keyword arguments (e.g., model parameters).

        Returns
        -------
            AsyncGenerator[str, None]: The generated response as a stream of strings.
        """
        new_kwargs = self._get_kwargs(**kwargs)
        messages: list[dict[str, str]] = history or []
        messages.append({"role": "user", "content": prompt})

        response = await self.acompletion(messages=messages, stream=True, **new_kwargs)  # type: ignore

        async for chunk in response:  # type: ignore
            if chunk.choices and chunk.choices[0].delta.content:
                yield chunk.choices[0].delta.content

    def chat(self, prompt: str, history: list | None = None, **kwargs: Any) -> "MR":
        """
        Generate a response for the given prompt and history.

        Args
        ----
            prompt: The prompt to generate a response for.
            history: Optional chat history.
            **kwargs: Additional keyword arguments (e.g., model parameters).

        Returns
        -------
            LitellmModelResponse: The generated model response.
        """
        new_kwargs = self._get_kwargs(**kwargs)
        messages: list[dict[str, str]] = history or []
        messages.append({"role": "user", "content": prompt})

        response = self.completion(messages=messages, stream=False, **new_kwargs)  # type: ignore

        messages.append({
            "role": "assistant",
            "content": response.choices[0].message.content or "",  # type: ignore
        })

        parsed_response: BaseModel | None = None
        if "response_format" in new_kwargs:
            parsed_dict: dict[str, Any] = json.loads(
                response.choices[0].message.content or "{}"  # type: ignore
            )
            parsed_response = parsed_dict  # type: ignore
            if inspect.isclass(new_kwargs["response_format"]) and issubclass(
                new_kwargs["response_format"], BaseModel
            ):
                # If response_format is a pydantic model, instantiate it
                model_initializer = cast(
                    "type[BaseModel]", new_kwargs["response_format"]
                )
                parsed_response = model_initializer(**parsed_dict)

        return LitellmModelResponse(
            output=LitellmModelOutput(
                content=response.choices[0].message.content or ""  # type: ignore
            ),
            parsed_response=parsed_response,
            history=messages,
        )

    def chat_stream(
        self, prompt: str, history: list | None = None, **kwargs: Any
    ) -> Generator[str, None]:
        """
        Generate a response for the given prompt and history.

        Args
        ----
            prompt: The prompt to generate a response for.
            history: Optional chat history.
            **kwargs: Additional keyword arguments (e.g., model parameters).

        Returns
        -------
            Generator[str, None]: The generated response as a stream of strings.
        """
        new_kwargs = self._get_kwargs(**kwargs)
        messages: list[dict[str, str]] = history or []
        messages.append({"role": "user", "content": prompt})

        response = self.completion(messages=messages, stream=True, **new_kwargs)  # type: ignore

        for chunk in response:
            if chunk.choices and chunk.choices[0].delta.content:  # type: ignore
                yield chunk.choices[0].delta.content  # type: ignore
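For orientation, a small hedged usage sketch of the new chat model. The model/provider/API-key values are placeholders, and in the pipeline these objects are normally created through ModelFactory rather than constructed directly; tokens_per_minute/requests_per_minute are set to None here to disable rate limiting, matching the validation rules shown earlier in this diff.

```python
from graphrag.config.models.language_model_config import LanguageModelConfig
from graphrag.language_model.providers.litellm.chat_model import LitellmChatModel

# Placeholder configuration; real settings come from graphrag's settings file.
config = LanguageModelConfig(
    type="chat",
    model_provider="openai",
    model="gpt-4o-mini",
    api_key="<your-api-key>",
    tokens_per_minute=None,
    requests_per_minute=None,
)

chat_model = LitellmChatModel(name="default_chat_model", config=config)

response = chat_model.chat("Summarize GraphRAG in one sentence.")
print(response.output.content)

for token in chat_model.chat_stream("Now stream the same summary."):
    print(token, end="")
```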
279  graphrag/language_model/providers/litellm/embedding_model.py  Normal file
@ -0,0 +1,279 @@
# Copyright (c) 2025 Microsoft Corporation.
# Licensed under the MIT License

"""Embedding model implementation using Litellm."""

from typing import TYPE_CHECKING, Any

import litellm
from azure.identity import DefaultAzureCredential, get_bearer_token_provider
from litellm import (
    EmbeddingResponse,  # type: ignore
    aembedding,
    embedding,
)

from graphrag.config.defaults import COGNITIVE_SERVICES_AUDIENCE
from graphrag.config.enums import AuthType
from graphrag.language_model.providers.litellm.request_wrappers.with_cache import (
    with_cache,
)
from graphrag.language_model.providers.litellm.request_wrappers.with_logging import (
    with_logging,
)
from graphrag.language_model.providers.litellm.request_wrappers.with_rate_limiter import (
    with_rate_limiter,
)
from graphrag.language_model.providers.litellm.request_wrappers.with_retries import (
    with_retries,
)
from graphrag.language_model.providers.litellm.types import (
    AFixedModelEmbedding,
    FixedModelEmbedding,
)

if TYPE_CHECKING:
    from graphrag.cache.pipeline_cache import PipelineCache
    from graphrag.config.models.language_model_config import LanguageModelConfig

litellm.suppress_debug_info = True


def _create_base_embeddings(
    model_config: "LanguageModelConfig",
) -> tuple[FixedModelEmbedding, AFixedModelEmbedding]:
    """Wrap the base litellm embedding function with the model configuration.

    Args
    ----
        model_config: The configuration for the language model.

    Returns
    -------
        A tuple containing the synchronous and asynchronous embedding functions.
    """
    model_provider = model_config.model_provider
    model = model_config.deployment_name or model_config.model

    base_args: dict[str, Any] = {
        "drop_params": True,  # LiteLLM drop unsupported params for selected model.
        "model": f"{model_provider}/{model}",
        "timeout": model_config.request_timeout,
        "api_base": model_config.api_base,
        "api_version": model_config.api_version,
        "api_key": model_config.api_key,
        "organization": model_config.organization,
        "proxy": model_config.proxy,
        "audience": model_config.audience,
    }

    if model_config.auth_type == AuthType.AzureManagedIdentity:
        if model_config.model_provider != "azure":
            msg = "Azure Managed Identity authentication is only supported for Azure models."
            raise ValueError(msg)

        base_args["azure_ad_token_provider"] = get_bearer_token_provider(
            DefaultAzureCredential(),
            COGNITIVE_SERVICES_AUDIENCE,
        )

    def _base_embedding(**kwargs: Any) -> EmbeddingResponse:
        new_args = {**base_args, **kwargs}

        if "name" in new_args:
            new_args.pop("name")

        return embedding(**new_args)

    async def _base_aembedding(**kwargs: Any) -> EmbeddingResponse:
        new_args = {**base_args, **kwargs}

        if "name" in new_args:
            new_args.pop("name")

        return await aembedding(**new_args)

    return (_base_embedding, _base_aembedding)


def _create_embeddings(
    model_config: "LanguageModelConfig",
    cache: "PipelineCache | None",
    cache_key_prefix: str,
) -> tuple[FixedModelEmbedding, AFixedModelEmbedding]:
    """Wrap the base litellm embedding function with the model configuration and additional features.

    Wrap the base litellm embedding function with instance variables based on the model configuration.
    Then wrap additional features such as rate limiting, retries, and caching, if enabled.

    Final function composition order:
    - Logging(Cache(Retries(RateLimiter(ModelEmbedding()))))

    Args
    ----
        model_config: The configuration for the language model.
        cache: Optional cache for storing responses.
        cache_key_prefix: Prefix for cache keys.

    Returns
    -------
        A tuple containing the synchronous and asynchronous embedding functions.

    """
    embedding, aembedding = _create_base_embeddings(model_config)

    # TODO: For v2.x release, rpm/tpm can be int or str (auto) for backwards compatibility with fnllm.
    # LiteLLM does not support "auto", so we have to check those values here.
    # For v3 release, force rpm/tpm to be int and remove the type checks below
    # and just check if rate_limit_strategy is enabled.
    if model_config.rate_limit_strategy is not None:
        rpm = (
            model_config.requests_per_minute
            if type(model_config.requests_per_minute) is int
            else None
        )
        tpm = (
            model_config.tokens_per_minute
            if type(model_config.tokens_per_minute) is int
            else None
        )
        if rpm is not None or tpm is not None:
            embedding, aembedding = with_rate_limiter(
                sync_fn=embedding,
                async_fn=aembedding,
                model_config=model_config,
                rpm=rpm,
                tpm=tpm,
            )

    if model_config.retry_strategy != "none":
        embedding, aembedding = with_retries(
            sync_fn=embedding,
            async_fn=aembedding,
            model_config=model_config,
        )

    if cache is not None:
        embedding, aembedding = with_cache(
            sync_fn=embedding,
            async_fn=aembedding,
            model_config=model_config,
            cache=cache,
            request_type="embedding",
            cache_key_prefix=cache_key_prefix,
        )

    embedding, aembedding = with_logging(
        sync_fn=embedding,
        async_fn=aembedding,
    )

    return (embedding, aembedding)


class LitellmEmbeddingModel:
    """LiteLLM-based Embedding Model."""

    def __init__(
        self,
        name: str,
        config: "LanguageModelConfig",
        cache: "PipelineCache | None" = None,
        **kwargs: Any,
    ):
        self.name = name
        self.config = config
        self.cache = cache.child(self.name) if cache else None
        self.embedding, self.aembedding = _create_embeddings(
            config, self.cache, "embeddings"
        )

    def _get_kwargs(self, **kwargs: Any) -> dict[str, Any]:
        """Get model arguments supported by litellm."""
        args_to_include = [
            "name",
            "dimensions",
            "encoding_format",
            "timeout",
            "user",
        ]
        return {k: v for k, v in kwargs.items() if k in args_to_include}

    async def aembed_batch(
        self, text_list: list[str], **kwargs: Any
    ) -> list[list[float]]:
        """
        Batch generate embeddings.

        Args
        ----
            text_list: A batch of text inputs to generate embeddings for.
            **kwargs: Additional keyword arguments (e.g., model parameters).

        Returns
        -------
            A Batch of embeddings.
        """
        new_kwargs = self._get_kwargs(**kwargs)
        response = await self.aembedding(input=text_list, **new_kwargs)

        return [emb.get("embedding", []) for emb in response.data]

    async def aembed(self, text: str, **kwargs: Any) -> list[float]:
        """
        Async embed.

        Args:
            text: The text to generate an embedding for.
            **kwargs: Additional keyword arguments (e.g., model parameters).

        Returns
        -------
            An embedding.
        """
        new_kwargs = self._get_kwargs(**kwargs)
        response = await self.aembedding(input=[text], **new_kwargs)

        return (
            response.data[0].get("embedding", [])
            if response.data and response.data[0]
            else []
        )

    def embed_batch(self, text_list: list[str], **kwargs: Any) -> list[list[float]]:
        """
        Batch generate embeddings.

        Args:
            text_list: A batch of text inputs to generate embeddings for.
            **kwargs: Additional keyword arguments (e.g., model parameters).

        Returns
        -------
            A Batch of embeddings.
        """
        new_kwargs = self._get_kwargs(**kwargs)
        response = self.embedding(input=text_list, **new_kwargs)

        return [emb.get("embedding", []) for emb in response.data]

    def embed(self, text: str, **kwargs: Any) -> list[float]:
        """
        Embed a single text input.

        Args:
            text: The text to generate an embedding for.
            **kwargs: Additional keyword arguments (e.g., model parameters).

        Returns
        -------
            An embedding.
        """
        new_kwargs = self._get_kwargs(**kwargs)
        response = self.embedding(input=[text], **new_kwargs)

        return (
            response.data[0].get("embedding", [])
            if response.data and response.data[0]
            else []
        )
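Likewise, a brief hedged sketch of the embedding model's surface area. It assumes `embedding_config` is a LanguageModelConfig built analogously to the chat example above, but with type="embedding" and an embedding model such as "text-embedding-3-small".

```python
from graphrag.language_model.providers.litellm.embedding_model import LitellmEmbeddingModel

embedder = LitellmEmbeddingModel(name="default_embedding_model", config=embedding_config)

vector = embedder.embed("graph rag")              # -> list[float]
vectors = embedder.embed_batch(["a", "b", "c"])   # -> list[list[float]]
```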
140  graphrag/language_model/providers/litellm/get_cache_key.py  Normal file
@ -0,0 +1,140 @@
# Copyright (c) 2025 Microsoft Corporation.
# Licensed under the MIT License

"""
LiteLLM cache key generation.

Modeled after the fnllm cache key generation.
https://github.com/microsoft/essex-toolkit/blob/23d3077b65c0e8f1d89c397a2968fe570a25f790/python/fnllm/fnllm/caching/base.py#L50
"""

import hashlib
import inspect
import json
from typing import TYPE_CHECKING, Any

from pydantic import BaseModel

if TYPE_CHECKING:
    from graphrag.config.models.language_model_config import LanguageModelConfig


_CACHE_VERSION = 3
"""
If there's a breaking change in what we cache, we should increment this version number to invalidate existing caches.

fnllm was on cache version 2 and though we generate
similar cache keys, the objects stored in cache by fnllm and litellm are different.
Using litellm model providers will not be able to reuse caches generated by fnllm
thus we start with version 3 for litellm.
"""


def get_cache_key(
    model_config: "LanguageModelConfig",
    prefix: str,
    messages: str | None = None,
    input: str | None = None,
    **kwargs: Any,
) -> str:
    """Generate a cache key based on the model configuration and input arguments.

    Modeled after the fnllm cache key generation.
    https://github.com/microsoft/essex-toolkit/blob/23d3077b65c0e8f1d89c397a2968fe570a25f790/python/fnllm/fnllm/caching/base.py#L50

    Args
    ____
        model_config: The configuration of the language model.
        prefix: A prefix for the cache key.
        **kwargs: Additional model input parameters.

    Returns
    -------
        `{prefix}_{data_hash}_v{version}` if prefix is provided.
    """
    cache_key: dict[str, Any] = {
        "parameters": _get_parameters(model_config, **kwargs),
    }

    if messages is not None and input is not None:
        msg = "Only one of 'messages' or 'input' should be provided."
        raise ValueError(msg)

    if messages is not None:
        cache_key["messages"] = messages
    elif input is not None:
        cache_key["input"] = input
    else:
        msg = "Either 'messages' or 'input' must be provided."
        raise ValueError(msg)

    data_hash = _hash(json.dumps(cache_key, sort_keys=True))

    name = kwargs.get("name")

    if name:
        prefix += f"_{name}"

    return f"{prefix}_{data_hash}_v{_CACHE_VERSION}"


def _get_parameters(
    model_config: "LanguageModelConfig",
    **kwargs: Any,
) -> dict[str, Any]:
    """Pluck out the parameters that define a cache key.

    Use the same parameters as fnllm except request timeout.
    - embeddings: https://github.com/microsoft/essex-toolkit/blob/main/python/fnllm/fnllm/openai/types/embeddings/parameters.py#L12
    - chat: https://github.com/microsoft/essex-toolkit/blob/main/python/fnllm/fnllm/openai/types/chat/parameters.py#L25

    Args
    ____
        model_config: The configuration of the language model.
        **kwargs: Additional model input parameters.

    Returns
    -------
        dict[str, Any]: A dictionary of parameters that define the cache key.
    """
    parameters = {
        "model": model_config.deployment_name or model_config.model,
        "frequency_penalty": model_config.frequency_penalty,
        "max_tokens": model_config.max_tokens,
        "max_completion_tokens": model_config.max_completion_tokens,
        "n": model_config.n,
        "presence_penalty": model_config.presence_penalty,
        "temperature": model_config.temperature,
        "top_p": model_config.top_p,
        "reasoning_effort": model_config.reasoning_effort,
    }
    keys_to_cache = [
        "function_call",
        "functions",
        "logit_bias",
        "logprobs",
        "parallel_tool_calls",
        "seed",
        "service_tier",
        "stop",
        "tool_choice",
        "tools",
        "top_logprobs",
        "user",
        "dimensions",
        "encoding_format",
    ]
    parameters.update({key: kwargs.get(key) for key in keys_to_cache if key in kwargs})

    response_format = kwargs.get("response_format")
    if inspect.isclass(response_format) and issubclass(response_format, BaseModel):
        parameters["response_format"] = str(response_format)
    elif response_format is not None:
        parameters["response_format"] = response_format

    return parameters


def _hash(input: str) -> str:
    """Generate a hash for the input string."""
    return hashlib.sha256(input.encode()).hexdigest()
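A quick hedged illustration of the key shape this produces. `model_config` is assumed to be an existing chat LanguageModelConfig, and the hash shown in the comment is schematic, not a real digest.

```python
from graphrag.language_model.providers.litellm.get_cache_key import get_cache_key

key = get_cache_key(
    model_config=model_config,
    prefix="chat",
    messages='[{"role": "user", "content": "hello"}]',
    seed=42,  # one of the keys_to_cache kwargs, so it participates in the hash
)
# key looks like: "chat_<sha256 of parameters + messages>_v3"
print(key)
```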
@ -0,0 +1,4 @@
# Copyright (c) 2025 Microsoft Corporation.
# Licensed under the MIT License

"""LiteLLM completion/embedding function wrappers."""
@ -0,0 +1,107 @@
# Copyright (c) 2025 Microsoft Corporation.
# Licensed under the MIT License

"""LiteLLM completion/embedding cache wrapper."""

import asyncio
from typing import TYPE_CHECKING, Any, Literal

from litellm import EmbeddingResponse, ModelResponse  # type: ignore

from graphrag.language_model.providers.litellm.get_cache_key import get_cache_key
from graphrag.language_model.providers.litellm.types import (
    AsyncLitellmRequestFunc,
    LitellmRequestFunc,
)

if TYPE_CHECKING:
    from graphrag.cache.pipeline_cache import PipelineCache
    from graphrag.config.models.language_model_config import LanguageModelConfig


def with_cache(
    *,
    sync_fn: LitellmRequestFunc,
    async_fn: AsyncLitellmRequestFunc,
    model_config: "LanguageModelConfig",
    cache: "PipelineCache",
    request_type: Literal["chat", "embedding"],
    cache_key_prefix: str,
) -> tuple[LitellmRequestFunc, AsyncLitellmRequestFunc]:
    """
    Wrap the synchronous and asynchronous request functions with caching.

    Args
    ----
        sync_fn: The synchronous chat/embedding request function to wrap.
        async_fn: The asynchronous chat/embedding request function to wrap.
        model_config: The configuration for the language model.
        cache: The cache to use for storing responses.
        request_type: The type of request being made, either "chat" or "embedding".
        cache_key_prefix: The prefix to use for cache keys.

    Returns
    -------
        A tuple containing the wrapped synchronous and asynchronous chat/embedding request functions.
    """

    def _wrapped_with_cache(**kwargs: Any) -> Any:
        is_streaming = kwargs.get("stream", False)
        if is_streaming:
            return sync_fn(**kwargs)
        cache_key = get_cache_key(
            model_config=model_config, prefix=cache_key_prefix, **kwargs
        )
        event_loop = asyncio.get_event_loop()
        cached_response = event_loop.run_until_complete(cache.get(cache_key))
        if (
            cached_response is not None
            and isinstance(cached_response, dict)
            and "response" in cached_response
            and cached_response["response"] is not None
            and isinstance(cached_response["response"], dict)
        ):
            try:
                if request_type == "chat":
                    return ModelResponse(**cached_response["response"])
                return EmbeddingResponse(**cached_response["response"])
            except Exception:  # noqa: BLE001
                # Try to retrieve value from cache but if it fails, continue
                # to make the request.
                ...
        response = sync_fn(**kwargs)
        event_loop.run_until_complete(
            cache.set(cache_key, {"response": response.model_dump()})
        )
        return response

    async def _wrapped_with_cache_async(
        **kwargs: Any,
    ) -> Any:
        is_streaming = kwargs.get("stream", False)
        if is_streaming:
            return await async_fn(**kwargs)
        cache_key = get_cache_key(
            model_config=model_config, prefix=cache_key_prefix, **kwargs
        )
        cached_response = await cache.get(cache_key)
        if (
            cached_response is not None
            and isinstance(cached_response, dict)
            and "response" in cached_response
            and cached_response["response"] is not None
            and isinstance(cached_response["response"], dict)
        ):
            try:
                if request_type == "chat":
                    return ModelResponse(**cached_response["response"])
                return EmbeddingResponse(**cached_response["response"])
            except Exception:  # noqa: BLE001
                # Try to retrieve value from cache but if it fails, continue
                # to make the request.
                ...
        response = await async_fn(**kwargs)
        await cache.set(cache_key, {"response": response.model_dump()})
        return response

    return (_wrapped_with_cache, _wrapped_with_cache_async)
@ -0,0 +1,56 @@
# Copyright (c) 2025 Microsoft Corporation.
# Licensed under the MIT License

"""LiteLLM completion/embedding logging wrapper."""

import logging
from typing import Any

from graphrag.language_model.providers.litellm.types import (
    AsyncLitellmRequestFunc,
    LitellmRequestFunc,
)

logger = logging.getLogger(__name__)


def with_logging(
    *,
    sync_fn: LitellmRequestFunc,
    async_fn: AsyncLitellmRequestFunc,
) -> tuple[LitellmRequestFunc, AsyncLitellmRequestFunc]:
    """
    Wrap the synchronous and asynchronous request functions with logging.

    Args
    ----
    sync_fn: The synchronous chat/embedding request function to wrap.
    async_fn: The asynchronous chat/embedding request function to wrap.

    Returns
    -------
    A tuple containing the wrapped synchronous and asynchronous chat/embedding request functions.
    """

    def _wrapped_with_logging(**kwargs: Any) -> Any:
        try:
            return sync_fn(**kwargs)
        except Exception as e:
            logger.exception(
                f"with_logging: Request failed with exception={e}",  # noqa: G004, TRY401
            )
            raise

    async def _wrapped_with_logging_async(
        **kwargs: Any,
    ) -> Any:
        try:
            return await async_fn(**kwargs)
        except Exception as e:
            logger.exception(
                f"with_logging: Async request failed with exception={e}",  # noqa: G004, TRY401
            )
            raise

    return (_wrapped_with_logging, _wrapped_with_logging_async)
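A minimal usage sketch (not part of the change): `with_logging` only needs the two request callables, so it can be exercised with plain stand-ins (`fake_chat` and `fake_chat_async` below are hypothetical):

import asyncio


def fake_chat(**kwargs):
    # Hypothetical stand-in for a litellm completion call.
    return {"echo": kwargs.get("messages")}


async def fake_chat_async(**kwargs):
    return {"echo": kwargs.get("messages")}


sync_chat, async_chat = with_logging(sync_fn=fake_chat, async_fn=fake_chat_async)
sync_chat(messages=[{"role": "user", "content": "hi"}])
asyncio.run(async_chat(messages=[{"role": "user", "content": "hi"}]))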
@ -0,0 +1,97 @@
# Copyright (c) 2025 Microsoft Corporation.
# Licensed under the MIT License

"""LiteLLM completion/embedding rate limiter wrapper."""

from typing import TYPE_CHECKING, Any

from litellm import token_counter  # type: ignore

from graphrag.language_model.providers.litellm.services.rate_limiter.rate_limiter_factory import (
    RateLimiterFactory,
)
from graphrag.language_model.providers.litellm.types import (
    AsyncLitellmRequestFunc,
    LitellmRequestFunc,
)

if TYPE_CHECKING:
    from graphrag.config.models.language_model_config import LanguageModelConfig


def with_rate_limiter(
    *,
    sync_fn: LitellmRequestFunc,
    async_fn: AsyncLitellmRequestFunc,
    model_config: "LanguageModelConfig",
    rpm: int | None = None,
    tpm: int | None = None,
) -> tuple[LitellmRequestFunc, AsyncLitellmRequestFunc]:
    """
    Wrap the synchronous and asynchronous request functions with rate limiting.

    Args
    ----
    sync_fn: The synchronous chat/embedding request function to wrap.
    async_fn: The asynchronous chat/embedding request function to wrap.
    model_config: The configuration for the language model.
    rpm: An optional requests per minute limit.
    tpm: An optional tokens per minute limit.

    If both `rpm` and `tpm` are set to None, rate limiting is disabled.

    Returns
    -------
    A tuple containing the wrapped synchronous and asynchronous chat/embedding request functions.
    """
    rate_limiter_factory = RateLimiterFactory()

    if (
        model_config.rate_limit_strategy is None
        or model_config.rate_limit_strategy not in rate_limiter_factory
    ):
        msg = f"Rate Limiter strategy '{model_config.rate_limit_strategy}' is none or not registered. Available strategies: {', '.join(rate_limiter_factory.keys())}"
        raise ValueError(msg)

    rate_limiter_service = rate_limiter_factory.create(
        strategy=model_config.rate_limit_strategy, rpm=rpm, tpm=tpm
    )

    max_tokens = model_config.max_completion_tokens or model_config.max_tokens or 0

    def _wrapped_with_rate_limiter(**kwargs: Any) -> Any:
        token_count = max_tokens
        if "messages" in kwargs:
            token_count += token_counter(
                model=model_config.model,
                messages=kwargs["messages"],
            )
        elif "input" in kwargs:
            token_count += token_counter(
                model=model_config.model,
                text=kwargs["input"],
            )

        with rate_limiter_service.acquire(token_count=token_count):
            return sync_fn(**kwargs)

    async def _wrapped_with_rate_limiter_async(
        **kwargs: Any,
    ) -> Any:
        token_count = max_tokens
        if "messages" in kwargs:
            token_count += token_counter(
                model=model_config.model,
                messages=kwargs["messages"],
            )
        elif "input" in kwargs:
            token_count += token_counter(
                model=model_config.model,
                text=kwargs["input"],
            )

        with rate_limiter_service.acquire(token_count=token_count):
            return await async_fn(**kwargs)

    return (_wrapped_with_rate_limiter, _wrapped_with_rate_limiter_async)
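The token budget handed to the limiter is the configured maximum completion size plus the prompt size as estimated by `litellm.token_counter`. A rough standalone sketch (not part of the change; the model name and the 1000-token completion budget are just example values):

from litellm import token_counter

prompt_tokens = token_counter(
    model="gpt-4o-mini",  # example model name, not taken from this change
    messages=[{"role": "user", "content": "What is GraphRAG?"}],
)
estimated_request_tokens = prompt_tokens + 1000  # plus the configured completion budget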
@ -0,0 +1,54 @@
# Copyright (c) 2025 Microsoft Corporation.
# Licensed under the MIT License

"""LiteLLM completion/embedding retries wrapper."""

from typing import TYPE_CHECKING, Any

from graphrag.language_model.providers.litellm.services.retry.retry_factory import (
    RetryFactory,
)
from graphrag.language_model.providers.litellm.types import (
    AsyncLitellmRequestFunc,
    LitellmRequestFunc,
)

if TYPE_CHECKING:
    from graphrag.config.models.language_model_config import LanguageModelConfig


def with_retries(
    *,
    sync_fn: LitellmRequestFunc,
    async_fn: AsyncLitellmRequestFunc,
    model_config: "LanguageModelConfig",
) -> tuple[LitellmRequestFunc, AsyncLitellmRequestFunc]:
    """
    Wrap the synchronous and asynchronous request functions with retries.

    Args
    ----
    sync_fn: The synchronous chat/embedding request function to wrap.
    async_fn: The asynchronous chat/embedding request function to wrap.
    model_config: The configuration for the language model.

    Returns
    -------
    A tuple containing the wrapped synchronous and asynchronous chat/embedding request functions.
    """
    retry_factory = RetryFactory()
    retry_service = retry_factory.create(
        strategy=model_config.retry_strategy,
        max_attempts=model_config.max_retries,
        max_retry_wait=model_config.max_retry_wait,
    )

    def _wrapped_with_retries(**kwargs: Any) -> Any:
        return retry_service.retry(func=sync_fn, **kwargs)

    async def _wrapped_with_retries_async(
        **kwargs: Any,
    ) -> Any:
        return await retry_service.aretry(func=async_fn, **kwargs)

    return (_wrapped_with_retries, _wrapped_with_retries_async)
@ -0,0 +1,4 @@
# Copyright (c) 2025 Microsoft Corporation.
# Licensed under the MIT License

"""LiteLLM Services."""
@ -0,0 +1,4 @@
# Copyright (c) 2025 Microsoft Corporation.
# Licensed under the MIT License

"""LiteLLM Rate Limiter."""
@ -0,0 +1,37 @@
# Copyright (c) 2025 Microsoft Corporation.
# Licensed under the MIT License

"""LiteLLM Rate Limiter."""

from abc import ABC, abstractmethod
from collections.abc import Iterator
from contextlib import contextmanager
from typing import Any


class RateLimiter(ABC):
    """Abstract base class for rate limiters."""

    @abstractmethod
    def __init__(
        self,
        /,
        **kwargs: Any,
    ) -> None: ...

    @abstractmethod
    @contextmanager
    def acquire(self, *, token_count: int) -> Iterator[None]:
        """
        Acquire Rate Limiter.

        Args
        ----
        token_count: The estimated number of tokens for the current request.

        Yields
        ------
        None: This context manager does not return any value.
        """
        msg = "RateLimiter subclasses must implement the acquire method."
        raise NotImplementedError(msg)
@ -0,0 +1,22 @@
# Copyright (c) 2025 Microsoft Corporation.
# Licensed under the MIT License

"""LiteLLM Rate Limiter Factory."""

from graphrag.config.defaults import DEFAULT_RATE_LIMITER_SERVICES
from graphrag.factory.factory import Factory
from graphrag.language_model.providers.litellm.services.rate_limiter.rate_limiter import (
    RateLimiter,
)


class RateLimiterFactory(Factory[RateLimiter]):
    """Singleton factory for creating rate limiter services."""


rate_limiter_factory = RateLimiterFactory()

for service_name, service_cls in DEFAULT_RATE_LIMITER_SERVICES.items():
    rate_limiter_factory.register(
        strategy=service_name, service_initializer=service_cls
    )
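A sketch (not part of the change) of how an additional strategy could plausibly be registered, assuming `register` and `create` behave as they are used in this change; `NoOpRateLimiter` and the "noop" strategy name are hypothetical:

from collections.abc import Iterator
from contextlib import contextmanager
from typing import Any


class NoOpRateLimiter(RateLimiter):
    # Hypothetical limiter that never blocks (illustrative only).
    def __init__(self, /, **kwargs: Any) -> None:
        pass

    @contextmanager
    def acquire(self, *, token_count: int) -> Iterator[None]:
        yield


rate_limiter_factory.register(strategy="noop", service_initializer=NoOpRateLimiter)
limiter = rate_limiter_factory.create(strategy="noop", rpm=None, tpm=None)
with limiter.acquire(token_count=0):
    pass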
@ -0,0 +1,133 @@
# Copyright (c) 2025 Microsoft Corporation.
# Licensed under the MIT License

"""LiteLLM Static Rate Limiter."""

import threading
import time
from collections import deque
from collections.abc import Iterator
from contextlib import contextmanager
from typing import Any

from graphrag.language_model.providers.litellm.services.rate_limiter.rate_limiter import (
    RateLimiter,
)


class StaticRateLimiter(RateLimiter):
    """Static Rate Limiter implementation."""

    def __init__(
        self,
        *,
        rpm: int | None = None,
        tpm: int | None = None,
        default_stagger: float = 0.0,
        period_in_seconds: int = 60,
        **kwargs: Any,
    ):
        if rpm is None and tpm is None:
            msg = "Both TPM and RPM cannot be None (disabled), one or both must be set to a positive integer."
            raise ValueError(msg)
        if (rpm is not None and rpm <= 0) or (tpm is not None and tpm <= 0):
            msg = "RPM and TPM must be either None (disabled) or positive integers."
            raise ValueError(msg)
        if default_stagger < 0:
            msg = "default_stagger must be >= 0."
            raise ValueError(msg)
        if period_in_seconds <= 0:
            msg = "Period in seconds must be a positive integer."
            raise ValueError(msg)
        self.rpm = rpm
        self.tpm = tpm
        self._lock = threading.Lock()
        self.rate_queue: deque[float] = deque()
        self.token_queue: deque[int] = deque()
        self.period_in_seconds = period_in_seconds
        self._last_time: float | None = None

        self.stagger = default_stagger
        if self.rpm is not None and self.rpm > 0:
            self.stagger = self.period_in_seconds / self.rpm

    @contextmanager
    def acquire(self, *, token_count: int) -> Iterator[None]:
        """
        Acquire Rate Limiter.

        Args
        ----
        token_count: The estimated number of tokens for the current request.

        Yields
        ------
        None: This context manager does not return any value.
        """
        while True:
            with self._lock:
                current_time = time.time()

                # Use two sliding windows to keep track of #requests and tokens per period
                # Drop old requests and tokens out of the sliding windows
                while (
                    len(self.rate_queue) > 0
                    and self.rate_queue[0] < current_time - self.period_in_seconds
                ):
                    self.rate_queue.popleft()
                    self.token_queue.popleft()

                # If the sliding window still exceeds the request limit, wait again
                # Waiting requires reacquiring the lock, allowing other threads
                # to see if their request fits within the rate limiting windows
                # Makes more sense for token limit than request limit
                if (
                    self.rpm is not None
                    and self.rpm > 0
                    and len(self.rate_queue) >= self.rpm
                ):
                    continue

                # Check if the current token window exceeds the token limit
                # If it does, wait again
                # This does not account for the tokens from the current request
                # This is intentional, as we want to allow the current request
                # to be processed if it is larger than the tpm but smaller than the context window.
                # tpm is a rate/soft limit and not the hard limit of the context window.
                if (
                    self.tpm is not None
                    and self.tpm > 0
                    and sum(self.token_queue) >= self.tpm
                ):
                    continue

                # This check ensures the current request's token usage
                # still fits within the token limit for the window.
                # If the current request alone exceeds the token limit,
                # then let it be processed anyway.
                if (
                    self.tpm is not None
                    and self.tpm > 0
                    and token_count <= self.tpm
                    and sum(self.token_queue) + token_count > self.tpm
                ):
                    continue

                # If there was a previous call, check if we need to stagger
                if (
                    self.stagger > 0
                    and (
                        self._last_time  # is None if this is the first hit to the rate limiter
                        and current_time - self._last_time
                        < self.stagger  # If more time has passed than the stagger time, we can proceed
                    )
                ):
                    time.sleep(self.stagger - (current_time - self._last_time))
                    current_time = time.time()

                # Add the current request to the sliding window
                self.rate_queue.append(current_time)
                self.token_queue.append(token_count)
                self._last_time = current_time
                break
        yield
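A small usage sketch (not part of the change), based on the class above: the limiter is shared across threads, `acquire` blocks until the request fits both windows, and `rpm=120` yields a stagger of 60 / 120 = 0.5 seconds between requests:

import threading

limiter = StaticRateLimiter(rpm=120, tpm=10_000)


def send(request_tokens: int) -> None:
    with limiter.acquire(token_count=request_tokens):
        ...  # issue the chat/embedding request here


threads = [threading.Thread(target=send, args=(250,)) for _ in range(4)]
for t in threads:
    t.start()
for t in threads:
    t.join()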
@ -0,0 +1,4 @@
# Copyright (c) 2025 Microsoft Corporation.
# Licensed under the MIT License

"""LiteLLM Retry Services."""
@ -0,0 +1,83 @@
# Copyright (c) 2025 Microsoft Corporation.
# Licensed under the MIT License

"""LiteLLM Exponential Retry Service."""

import asyncio
import logging
import random
import time
from collections.abc import Awaitable, Callable
from typing import Any

from graphrag.language_model.providers.litellm.services.retry.retry import Retry

logger = logging.getLogger(__name__)


class ExponentialRetry(Retry):
    """LiteLLM Exponential Retry Service."""

    def __init__(
        self,
        *,
        max_attempts: int = 5,
        base_delay: float = 2.0,
        jitter: bool = True,
        **kwargs: Any,
    ):
        if max_attempts <= 0:
            msg = "max_attempts must be greater than 0."
            raise ValueError(msg)

        if base_delay <= 1.0:
            msg = "base_delay must be greater than 1.0."
            raise ValueError(msg)

        self._max_attempts = max_attempts
        self._base_delay = base_delay
        self._jitter = jitter

    def retry(self, func: Callable[..., Any], **kwargs: Any) -> Any:
        """Retry a synchronous function."""
        retries = 0
        delay = 1.0  # Initial delay in seconds
        while True:
            try:
                return func(**kwargs)
            except Exception as e:
                if retries >= self._max_attempts:
                    logger.exception(
                        f"ExponentialRetry: Max retries exceeded, retries={retries}, max_retries={self._max_attempts}, exception={e}",  # noqa: G004, TRY401
                    )
                    raise
                retries += 1
                delay *= self._base_delay
                logger.exception(
                    f"ExponentialRetry: Request failed, retrying, retries={retries}, delay={delay}, max_retries={self._max_attempts}, exception={e}",  # noqa: G004, TRY401
                )
                time.sleep(delay + (self._jitter * random.uniform(0, 1)))  # noqa: S311

    async def aretry(
        self,
        func: Callable[..., Awaitable[Any]],
        **kwargs: Any,
    ) -> Any:
        """Retry an asynchronous function."""
        retries = 0
        delay = 1.0  # Initial delay in seconds
        while True:
            try:
                return await func(**kwargs)
            except Exception as e:
                if retries >= self._max_attempts:
                    logger.exception(
                        f"ExponentialRetry: Max retries exceeded, retries={retries}, max_retries={self._max_attempts}, exception={e}",  # noqa: G004, TRY401
                    )
                    raise
                retries += 1
                delay *= self._base_delay
                logger.exception(
                    f"ExponentialRetry: Request failed, retrying, retries={retries}, delay={delay}, max_retries={self._max_attempts}, exception={e}",  # noqa: G004, TRY401
                )
                await asyncio.sleep(delay + (self._jitter * random.uniform(0, 1)))  # noqa: S311
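With the defaults above (initial `delay = 1.0`, `base_delay = 2.0`), the sleep before retry n is roughly 2, 4, 8, ... seconds plus up to one second of jitter. A minimal sketch (not part of the change; `flaky` is a hypothetical failing function, and the call sleeps for a few seconds):

attempts = {"count": 0}


def flaky(**kwargs):
    # Hypothetical request that fails twice before succeeding.
    attempts["count"] += 1
    if attempts["count"] < 3:
        msg = "transient failure"
        raise RuntimeError(msg)
    return "ok"


retry_service = ExponentialRetry(max_attempts=5, base_delay=2.0, jitter=False)
assert retry_service.retry(func=flaky) == "ok"  # third attempt succeeds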
@ -0,0 +1,77 @@
# Copyright (c) 2025 Microsoft Corporation.
# Licensed under the MIT License

"""LiteLLM Incremental Wait Retry Service."""

import asyncio
import logging
import time
from collections.abc import Awaitable, Callable
from typing import Any

from graphrag.language_model.providers.litellm.services.retry.retry import Retry

logger = logging.getLogger(__name__)


class IncrementalWaitRetry(Retry):
    """LiteLLM Incremental Wait Retry Service."""

    def __init__(
        self,
        *,
        max_retry_wait: float,
        max_attempts: int = 5,
        **kwargs: Any,
    ):
        if max_attempts <= 0 or max_retry_wait <= 0:
            msg = "max_attempts and max_retry_wait must be greater than 0."
            raise ValueError(msg)

        self._max_attempts = max_attempts
        self._max_retry_wait = max_retry_wait
        self._increment = max_retry_wait / max_attempts

    def retry(self, func: Callable[..., Any], **kwargs: Any) -> Any:
        """Retry a synchronous function."""
        retries = 0
        delay = 0.0
        while True:
            try:
                return func(**kwargs)
            except Exception as e:
                if retries >= self._max_attempts:
                    logger.exception(
                        f"IncrementalWaitRetry: Max retries exceeded, retries={retries}, max_retries={self._max_attempts}, exception={e}",  # noqa: G004, TRY401
                    )
                    raise
                retries += 1
                delay += self._increment
                logger.exception(
                    f"IncrementalWaitRetry: Request failed, retrying after incremental delay, retries={retries}, delay={delay}, max_retries={self._max_attempts}, exception={e}",  # noqa: G004, TRY401
                )
                time.sleep(delay)

    async def aretry(
        self,
        func: Callable[..., Awaitable[Any]],
        **kwargs: Any,
    ) -> Any:
        """Retry an asynchronous function."""
        retries = 0
        delay = 0.0
        while True:
            try:
                return await func(**kwargs)
            except Exception as e:
                if retries >= self._max_attempts:
                    logger.exception(
                        f"IncrementalWaitRetry: Max retries exceeded, retries={retries}, max_retries={self._max_attempts}, exception={e}",  # noqa: G004, TRY401
                    )
                    raise
                retries += 1
                delay += self._increment
                logger.exception(
                    f"IncrementalWaitRetry: Request failed, retrying after incremental delay, retries={retries}, delay={delay}, max_retries={self._max_attempts}, exception={e}",  # noqa: G004, TRY401
                )
                await asyncio.sleep(delay)
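The wait schedule above is linear: every failure adds `max_retry_wait / max_attempts` seconds, so the i-th retry sleeps roughly i times that increment and the final allowed retry waits close to `max_retry_wait`. A quick check (not part of the change):

max_retry_wait, max_attempts = 10.0, 5
increment = max_retry_wait / max_attempts  # 2.0 seconds added per failure
delays = [round(i * increment, 1) for i in range(1, max_attempts + 1)]
assert delays == [2.0, 4.0, 6.0, 8.0, 10.0]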
@ -0,0 +1,66 @@
# Copyright (c) 2025 Microsoft Corporation.
# Licensed under the MIT License

"""LiteLLM Native Retry Service."""

import logging
from collections.abc import Awaitable, Callable
from typing import Any

from graphrag.language_model.providers.litellm.services.retry.retry import Retry

logger = logging.getLogger(__name__)


class NativeRetry(Retry):
    """LiteLLM Native Retry Service."""

    def __init__(
        self,
        *,
        max_attempts: int = 5,
        **kwargs: Any,
    ):
        if max_attempts <= 0:
            msg = "max_attempts must be greater than 0."
            raise ValueError(msg)

        self._max_attempts = max_attempts

    def retry(self, func: Callable[..., Any], **kwargs: Any) -> Any:
        """Retry a synchronous function."""
        retries = 0
        while True:
            try:
                return func(**kwargs)
            except Exception as e:
                if retries >= self._max_attempts:
                    logger.exception(
                        f"NativeRetry: Max retries exceeded, retries={retries}, max_retries={self._max_attempts}, exception={e}",  # noqa: G004, TRY401
                    )
                    raise
                retries += 1
                logger.exception(
                    f"NativeRetry: Request failed, immediately retrying, retries={retries}, max_retries={self._max_attempts}, exception={e}",  # noqa: G004, TRY401
                )

    async def aretry(
        self,
        func: Callable[..., Awaitable[Any]],
        **kwargs: Any,
    ) -> Any:
        """Retry an asynchronous function."""
        retries = 0
        while True:
            try:
                return await func(**kwargs)
            except Exception as e:
                if retries >= self._max_attempts:
                    logger.exception(
                        f"NativeRetry: Max retries exceeded, retries={retries}, max_retries={self._max_attempts}, exception={e}",  # noqa: G004, TRY401
                    )
                    raise
                retries += 1
                logger.exception(
                    f"NativeRetry: Request failed, immediately retrying, retries={retries}, max_retries={self._max_attempts}, exception={e}",  # noqa: G004, TRY401
                )
@ -0,0 +1,75 @@
# Copyright (c) 2025 Microsoft Corporation.
# Licensed under the MIT License

"""LiteLLM Random Wait Retry Service."""

import asyncio
import logging
import random
import time
from collections.abc import Awaitable, Callable
from typing import Any

from graphrag.language_model.providers.litellm.services.retry.retry import Retry

logger = logging.getLogger(__name__)


class RandomWaitRetry(Retry):
    """LiteLLM Random Wait Retry Service."""

    def __init__(
        self,
        *,
        max_retry_wait: float,
        max_attempts: int = 5,
        **kwargs: Any,
    ):
        if max_attempts <= 0 or max_retry_wait <= 0:
            msg = "max_attempts and max_retry_wait must be greater than 0."
            raise ValueError(msg)

        self._max_attempts = max_attempts
        self._max_retry_wait = max_retry_wait

    def retry(self, func: Callable[..., Any], **kwargs: Any) -> Any:
        """Retry a synchronous function."""
        retries = 0
        while True:
            try:
                return func(**kwargs)
            except Exception as e:
                if retries >= self._max_attempts:
                    logger.exception(
                        f"RandomWaitRetry: Max retries exceeded, retries={retries}, max_retries={self._max_attempts}, exception={e}",  # noqa: G004, TRY401
                    )
                    raise
                retries += 1
                delay = random.uniform(0, self._max_retry_wait)  # noqa: S311
                logger.exception(
                    f"RandomWaitRetry: Request failed, retrying after random delay, retries={retries}, delay={delay}, max_retries={self._max_attempts}, exception={e}",  # noqa: G004, TRY401
                )
                time.sleep(delay)

    async def aretry(
        self,
        func: Callable[..., Awaitable[Any]],
        **kwargs: Any,
    ) -> Any:
        """Retry an asynchronous function."""
        retries = 0
        while True:
            try:
                return await func(**kwargs)
            except Exception as e:
                if retries >= self._max_attempts:
                    logger.exception(
                        f"RandomWaitRetry: Max retries exceeded, retries={retries}, max_retries={self._max_attempts}, exception={e}",  # noqa: G004, TRY401
                    )
                    raise
                retries += 1
                delay = random.uniform(0, self._max_retry_wait)  # noqa: S311
                logger.exception(
                    f"RandomWaitRetry: Request failed, retrying after random delay, retries={retries}, delay={delay}, max_retries={self._max_attempts}, exception={e}",  # noqa: G004, TRY401
                )
                await asyncio.sleep(delay)
@ -0,0 +1,33 @@
# Copyright (c) 2025 Microsoft Corporation.
# Licensed under the MIT License

"""LiteLLM Retry Abstract Base Class."""

from abc import ABC, abstractmethod
from collections.abc import Awaitable, Callable
from typing import Any


class Retry(ABC):
    """LiteLLM Retry Abstract Base Class."""

    @abstractmethod
    def __init__(self, /, **kwargs: Any):
        msg = "Retry subclasses must implement the __init__ method."
        raise NotImplementedError(msg)

    @abstractmethod
    def retry(self, func: Callable[..., Any], **kwargs: Any) -> Any:
        """Retry a synchronous function."""
        msg = "Subclasses must implement this method"
        raise NotImplementedError(msg)

    @abstractmethod
    async def aretry(
        self,
        func: Callable[..., Awaitable[Any]],
        **kwargs: Any,
    ) -> Any:
        """Retry an asynchronous function."""
        msg = "Subclasses must implement this method"
        raise NotImplementedError(msg)
@ -0,0 +1,18 @@
# Copyright (c) 2025 Microsoft Corporation.
# Licensed under the MIT License

"""LiteLLM Retry Factory."""

from graphrag.config.defaults import DEFAULT_RETRY_SERVICES
from graphrag.factory.factory import Factory
from graphrag.language_model.providers.litellm.services.retry.retry import Retry


class RetryFactory(Factory[Retry]):
    """Singleton factory for creating retry services."""


retry_factory = RetryFactory()

for service_name, service_cls in DEFAULT_RETRY_SERVICES.items():
    retry_factory.register(strategy=service_name, service_initializer=service_cls)
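As with the rate limiter factory, an additional retry strategy could plausibly be plugged in the same way; `NeverRetry` and the "never" strategy name below are hypothetical, and the sketch only assumes the `register`/`create` calls used elsewhere in this change:

from collections.abc import Awaitable, Callable
from typing import Any


class NeverRetry(Retry):
    # Hypothetical strategy that gives up on the first failure (illustrative only).
    def __init__(self, /, **kwargs: Any):
        pass

    def retry(self, func: Callable[..., Any], **kwargs: Any) -> Any:
        return func(**kwargs)

    async def aretry(self, func: Callable[..., Awaitable[Any]], **kwargs: Any) -> Any:
        return await func(**kwargs)


retry_factory.register(strategy="never", service_initializer=NeverRetry)
retry_service = retry_factory.create(strategy="never", max_attempts=1, max_retry_wait=1.0)
result = retry_service.retry(func=lambda **_: "ok")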
235 graphrag/language_model/providers/litellm/types.py Normal file
@ -0,0 +1,235 @@
|
||||
# Copyright (c) 2025 Microsoft Corporation.
|
||||
# Licensed under the MIT License
|
||||
|
||||
"""LiteLLM types."""
|
||||
|
||||
from typing import (
|
||||
Any,
|
||||
Protocol,
|
||||
runtime_checkable,
|
||||
)
|
||||
|
||||
from litellm import (
|
||||
AnthropicThinkingParam,
|
||||
BaseModel,
|
||||
ChatCompletionAudioParam,
|
||||
ChatCompletionModality,
|
||||
ChatCompletionPredictionContentParam,
|
||||
CustomStreamWrapper,
|
||||
EmbeddingResponse, # type: ignore
|
||||
ModelResponse, # type: ignore
|
||||
OpenAIWebSearchOptions,
|
||||
)
|
||||
from openai.types.chat.chat_completion import (
|
||||
ChatCompletion,
|
||||
Choice,
|
||||
)
|
||||
from openai.types.chat.chat_completion_chunk import ChatCompletionChunk, ChoiceDelta
|
||||
from openai.types.chat.chat_completion_chunk import Choice as ChunkChoice
|
||||
from openai.types.chat.chat_completion_message import ChatCompletionMessage
|
||||
from openai.types.chat.chat_completion_message_param import ChatCompletionMessageParam
|
||||
from openai.types.completion_usage import (
|
||||
CompletionTokensDetails,
|
||||
CompletionUsage,
|
||||
PromptTokensDetails,
|
||||
)
|
||||
from openai.types.create_embedding_response import CreateEmbeddingResponse, Usage
|
||||
from openai.types.embedding import Embedding
|
||||
|
||||
LMChatCompletionMessageParam = ChatCompletionMessageParam | dict[str, str]
|
||||
|
||||
LMChatCompletion = ChatCompletion
|
||||
LMChoice = Choice
|
||||
LMChatCompletionMessage = ChatCompletionMessage
|
||||
|
||||
LMChatCompletionChunk = ChatCompletionChunk
|
||||
LMChoiceChunk = ChunkChoice
|
||||
LMChoiceDelta = ChoiceDelta
|
||||
|
||||
LMCompletionUsage = CompletionUsage
|
||||
LMPromptTokensDetails = PromptTokensDetails
|
||||
LMCompletionTokensDetails = CompletionTokensDetails
|
||||
|
||||
|
||||
LMEmbeddingResponse = CreateEmbeddingResponse
|
||||
LMEmbedding = Embedding
|
||||
LMEmbeddingUsage = Usage
|
||||
|
||||
|
||||
@runtime_checkable
|
||||
class FixedModelCompletion(Protocol):
|
||||
"""
|
||||
Synchronous chat completion function.
|
||||
|
||||
Same signature as litellm.completion but without the `model` parameter
|
||||
as this is already set in the model configuration.
|
||||
"""
|
||||
|
||||
def __call__(
|
||||
self,
|
||||
*,
|
||||
messages: list = [], # type: ignore # noqa: B006
|
||||
stream: bool | None = None,
|
||||
stream_options: dict | None = None, # type: ignore
|
||||
stop=None, # type: ignore
|
||||
max_completion_tokens: int | None = None,
|
||||
max_tokens: int | None = None,
|
||||
modalities: list[ChatCompletionModality] | None = None,
|
||||
prediction: ChatCompletionPredictionContentParam | None = None,
|
||||
audio: ChatCompletionAudioParam | None = None,
|
||||
logit_bias: dict | None = None, # type: ignore
|
||||
user: str | None = None,
|
||||
# openai v1.0+ new params
|
||||
response_format: dict | type[BaseModel] | None = None, # type: ignore
|
||||
seed: int | None = None,
|
||||
tools: list | None = None, # type: ignore
|
||||
tool_choice: str | dict | None = None, # type: ignore
|
||||
logprobs: bool | None = None,
|
||||
top_logprobs: int | None = None,
|
||||
parallel_tool_calls: bool | None = None,
|
||||
web_search_options: OpenAIWebSearchOptions | None = None,
|
||||
deployment_id=None, # type: ignore
|
||||
extra_headers: dict | None = None, # type: ignore
|
||||
# soon to be deprecated params by OpenAI
|
||||
functions: list | None = None, # type: ignore
|
||||
function_call: str | None = None,
|
||||
# Optional liteLLM function params
|
||||
thinking: AnthropicThinkingParam | None = None,
|
||||
**kwargs: Any,
|
||||
) -> ModelResponse | CustomStreamWrapper:
|
||||
"""Chat completion function."""
|
||||
...
|
||||
|
||||
|
||||
@runtime_checkable
|
||||
class AFixedModelCompletion(Protocol):
|
||||
"""
|
||||
Asynchronous chat completion function.
|
||||
|
||||
Same signature as litellm.acompletion but without the `model` parameter
|
||||
as this is already set in the model configuration.
|
||||
"""
|
||||
|
||||
async def __call__(
|
||||
self,
|
||||
*,
|
||||
# Optional OpenAI params: see https://platform.openai.com/docs/api-reference/chat/create
|
||||
messages: list = [], # type: ignore # noqa: B006
|
||||
stream: bool | None = None,
|
||||
stream_options: dict | None = None, # type: ignore
|
||||
stop=None, # type: ignore
|
||||
max_completion_tokens: int | None = None,
|
||||
max_tokens: int | None = None,
|
||||
modalities: list[ChatCompletionModality] | None = None,
|
||||
prediction: ChatCompletionPredictionContentParam | None = None,
|
||||
audio: ChatCompletionAudioParam | None = None,
|
||||
logit_bias: dict | None = None, # type: ignore
|
||||
user: str | None = None,
|
||||
# openai v1.0+ new params
|
||||
response_format: dict | type[BaseModel] | None = None, # type: ignore
|
||||
seed: int | None = None,
|
||||
tools: list | None = None, # type: ignore
|
||||
tool_choice: str | dict | None = None, # type: ignore
|
||||
logprobs: bool | None = None,
|
||||
top_logprobs: int | None = None,
|
||||
parallel_tool_calls: bool | None = None,
|
||||
web_search_options: OpenAIWebSearchOptions | None = None,
|
||||
deployment_id=None, # type: ignore
|
||||
extra_headers: dict | None = None, # type: ignore
|
||||
# soon to be deprecated params by OpenAI
|
||||
functions: list | None = None, # type: ignore
|
||||
function_call: str | None = None,
|
||||
# Optional liteLLM function params
|
||||
thinking: AnthropicThinkingParam | None = None,
|
||||
**kwargs: Any,
|
||||
) -> ModelResponse | CustomStreamWrapper:
|
||||
"""Chat completion function."""
|
||||
...
|
||||
|
||||
|
||||
@runtime_checkable
|
||||
class FixedModelEmbedding(Protocol):
|
||||
"""
|
||||
Synchronous embedding function.
|
||||
|
||||
Same signature as litellm.embedding but without the `model` parameter
|
||||
as this is already set in the model configuration.
|
||||
"""
|
||||
|
||||
def __call__(
|
||||
self,
|
||||
*,
|
||||
request_id: str | None = None,
|
||||
input: list = [], # type: ignore # noqa: B006
|
||||
# Optional params
|
||||
dimensions: int | None = None,
|
||||
encoding_format: str | None = None,
|
||||
timeout: int = 600, # default to 10 minutes
|
||||
# set api_base, api_version, api_key
|
||||
api_base: str | None = None,
|
||||
api_version: str | None = None,
|
||||
api_key: str | None = None,
|
||||
api_type: str | None = None,
|
||||
caching: bool = False,
|
||||
user: str | None = None,
|
||||
**kwargs: Any,
|
||||
) -> EmbeddingResponse:
|
||||
"""Embedding function."""
|
||||
...
|
||||
|
||||
|
||||
@runtime_checkable
|
||||
class AFixedModelEmbedding(Protocol):
|
||||
"""
|
||||
Asynchronous embedding function.
|
||||
|
||||
Same signature as litellm.embedding but without the `model` parameter
|
||||
as this is already set in the model configuration.
|
||||
"""
|
||||
|
||||
async def __call__(
|
||||
self,
|
||||
*,
|
||||
request_id: str | None = None,
|
||||
input: list = [], # type: ignore # noqa: B006
|
||||
# Optional params
|
||||
dimensions: int | None = None,
|
||||
encoding_format: str | None = None,
|
||||
timeout: int = 600, # default to 10 minutes
|
||||
# set api_base, api_version, api_key
|
||||
api_base: str | None = None,
|
||||
api_version: str | None = None,
|
||||
api_key: str | None = None,
|
||||
api_type: str | None = None,
|
||||
caching: bool = False,
|
||||
user: str | None = None,
|
||||
**kwargs: Any,
|
||||
) -> EmbeddingResponse:
|
||||
"""Embedding function."""
|
||||
...
|
||||
|
||||
|
||||
@runtime_checkable
|
||||
class LitellmRequestFunc(Protocol):
|
||||
"""
|
||||
Synchronous request function.
|
||||
|
||||
Represents either a chat completion or embedding function.
|
||||
"""
|
||||
|
||||
def __call__(self, /, **kwargs: Any) -> Any:
|
||||
"""Request function."""
|
||||
...
|
||||
|
||||
|
||||
@runtime_checkable
|
||||
class AsyncLitellmRequestFunc(Protocol):
|
||||
"""
|
||||
Asynchronous request function.
|
||||
|
||||
Represents either a chat completion or embedding function.
|
||||
"""
|
||||
|
||||
async def __call__(self, /, **kwargs: Any) -> Any:
|
||||
"""Request function."""
|
||||
...
|
||||
@ -5,8 +5,6 @@
|
||||
|
||||
from pathlib import Path
|
||||
|
||||
import graphrag.config.defaults as defs
|
||||
from graphrag.index.utils.tokens import num_tokens_from_string
|
||||
from graphrag.prompt_tune.template.extract_graph import (
|
||||
EXAMPLE_EXTRACTION_TEMPLATE,
|
||||
GRAPH_EXTRACTION_JSON_PROMPT,
|
||||
@ -14,6 +12,8 @@ from graphrag.prompt_tune.template.extract_graph import (
|
||||
UNTYPED_EXAMPLE_EXTRACTION_TEMPLATE,
|
||||
UNTYPED_GRAPH_EXTRACTION_PROMPT,
|
||||
)
|
||||
from graphrag.tokenizer.get_tokenizer import get_tokenizer
|
||||
from graphrag.tokenizer.tokenizer import Tokenizer
|
||||
|
||||
EXTRACT_GRAPH_FILENAME = "extract_graph.txt"
|
||||
|
||||
@ -24,7 +24,7 @@ def create_extract_graph_prompt(
|
||||
examples: list[str],
|
||||
language: str,
|
||||
max_token_count: int,
|
||||
encoding_model: str = defs.ENCODING_MODEL,
|
||||
tokenizer: Tokenizer | None = None,
|
||||
json_mode: bool = False,
|
||||
output_path: Path | None = None,
|
||||
min_examples_required: int = 2,
|
||||
@ -38,7 +38,7 @@ def create_extract_graph_prompt(
|
||||
- docs (list[str]): The list of documents to extract entities from
|
||||
- examples (list[str]): The list of examples to use for entity extraction
|
||||
- language (str): The language of the inputs and outputs
|
||||
- encoding_model (str): The name of the model to use for token counting
|
||||
- tokenizer (Tokenizer): The tokenizer to use for encoding and decoding text.
|
||||
- max_token_count (int): The maximum number of tokens to use for the prompt
|
||||
- json_mode (bool): Whether to use JSON mode for the prompt. Default is False
|
||||
- output_path (Path | None): The path to write the prompt to. Default is None.
|
||||
@ -56,10 +56,12 @@ def create_extract_graph_prompt(
|
||||
if isinstance(entity_types, list):
|
||||
entity_types = ", ".join(map(str, entity_types))
|
||||
|
||||
tokenizer = tokenizer or get_tokenizer()
|
||||
|
||||
tokens_left = (
|
||||
max_token_count
|
||||
- num_tokens_from_string(prompt, encoding_name=encoding_model)
|
||||
- num_tokens_from_string(entity_types, encoding_name=encoding_model)
|
||||
- tokenizer.num_tokens(prompt)
|
||||
- tokenizer.num_tokens(entity_types)
|
||||
if entity_types
|
||||
else 0
|
||||
)
|
||||
@ -79,9 +81,7 @@ def create_extract_graph_prompt(
|
||||
)
|
||||
)
|
||||
|
||||
example_tokens = num_tokens_from_string(
|
||||
example_formatted, encoding_name=encoding_model
|
||||
)
|
||||
example_tokens = tokenizer.num_tokens(example_formatted)
|
||||
|
||||
# Ensure at least the minimum required number of examples is included
|
||||
if i >= min_examples_required and example_tokens > tokens_left:
|
||||
|
||||
@ -8,11 +8,11 @@ import random
|
||||
from typing import Any, cast
|
||||
|
||||
import pandas as pd
|
||||
import tiktoken
|
||||
|
||||
from graphrag.data_model.community_report import CommunityReport
|
||||
from graphrag.data_model.entity import Entity
|
||||
from graphrag.query.llm.text_utils import num_tokens
|
||||
from graphrag.tokenizer.get_tokenizer import get_tokenizer
|
||||
from graphrag.tokenizer.tokenizer import Tokenizer
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@ -24,7 +24,7 @@ NO_COMMUNITY_RECORDS_WARNING: str = (
|
||||
def build_community_context(
|
||||
community_reports: list[CommunityReport],
|
||||
entities: list[Entity] | None = None,
|
||||
token_encoder: tiktoken.Encoding | None = None,
|
||||
tokenizer: Tokenizer | None = None,
|
||||
use_community_summary: bool = True,
|
||||
column_delimiter: str = "|",
|
||||
shuffle_data: bool = True,
|
||||
@ -46,6 +46,7 @@ def build_community_context(
|
||||
|
||||
The calculated weight is added as an attribute to the community reports and added to the context data table.
|
||||
"""
|
||||
tokenizer = tokenizer or get_tokenizer()
|
||||
|
||||
def _is_included(report: CommunityReport) -> bool:
|
||||
return report.rank is not None and report.rank >= min_community_rank
|
||||
@ -125,7 +126,7 @@ def build_community_context(
|
||||
batch_text = (
|
||||
f"-----{context_name}-----" + "\n" + column_delimiter.join(header) + "\n"
|
||||
)
|
||||
batch_tokens = num_tokens(batch_text, token_encoder)
|
||||
batch_tokens = tokenizer.num_tokens(batch_text)
|
||||
batch_records = []
|
||||
|
||||
def _cut_batch() -> None:
|
||||
@ -152,7 +153,7 @@ def build_community_context(
|
||||
|
||||
for report in selected_reports:
|
||||
new_context_text, new_context = _report_context_text(report, attributes)
|
||||
new_tokens = num_tokens(new_context_text, token_encoder)
|
||||
new_tokens = tokenizer.num_tokens(new_context_text)
|
||||
|
||||
if batch_tokens + new_tokens > max_context_tokens:
|
||||
# add the current batch to the context data and start a new batch if we are in multi-batch mode
|
||||
|
||||
@ -7,9 +7,9 @@ from dataclasses import dataclass
|
||||
from enum import Enum
|
||||
|
||||
import pandas as pd
|
||||
import tiktoken
|
||||
|
||||
from graphrag.query.llm.text_utils import num_tokens
|
||||
from graphrag.tokenizer.get_tokenizer import get_tokenizer
|
||||
from graphrag.tokenizer.tokenizer import Tokenizer
|
||||
|
||||
"""
|
||||
Enum for conversation roles
|
||||
@ -148,7 +148,7 @@ class ConversationHistory:
|
||||
|
||||
def build_context(
|
||||
self,
|
||||
token_encoder: tiktoken.Encoding | None = None,
|
||||
tokenizer: Tokenizer | None = None,
|
||||
include_user_turns_only: bool = True,
|
||||
max_qa_turns: int | None = 5,
|
||||
max_context_tokens: int = 8000,
|
||||
@ -168,6 +168,7 @@ class ConversationHistory:
|
||||
context_name: Name of the context, default is "Conversation History".
|
||||
|
||||
"""
|
||||
tokenizer = tokenizer or get_tokenizer()
|
||||
qa_turns = self.to_qa_turns()
|
||||
if include_user_turns_only:
|
||||
qa_turns = [
|
||||
@ -202,7 +203,7 @@ class ConversationHistory:
|
||||
|
||||
context_df = pd.DataFrame(turn_list)
|
||||
context_text = header + context_df.to_csv(sep=column_delimiter, index=False)
|
||||
if num_tokens(context_text, token_encoder) > max_context_tokens:
|
||||
if tokenizer.num_tokens(context_text) > max_context_tokens:
|
||||
break
|
||||
|
||||
current_context_df = context_df
|
||||
|
||||
@ -10,13 +10,12 @@ from copy import deepcopy
|
||||
from time import time
|
||||
from typing import Any
|
||||
|
||||
import tiktoken
|
||||
|
||||
from graphrag.data_model.community import Community
|
||||
from graphrag.data_model.community_report import CommunityReport
|
||||
from graphrag.language_model.protocol.base import ChatModel
|
||||
from graphrag.query.context_builder.rate_prompt import RATE_QUERY
|
||||
from graphrag.query.context_builder.rate_relevancy import rate_relevancy
|
||||
from graphrag.tokenizer.tokenizer import Tokenizer
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@ -32,7 +31,7 @@ class DynamicCommunitySelection:
|
||||
community_reports: list[CommunityReport],
|
||||
communities: list[Community],
|
||||
model: ChatModel,
|
||||
token_encoder: tiktoken.Encoding,
|
||||
tokenizer: Tokenizer,
|
||||
rate_query: str = RATE_QUERY,
|
||||
use_summary: bool = False,
|
||||
threshold: int = 1,
|
||||
@ -43,7 +42,7 @@ class DynamicCommunitySelection:
|
||||
model_params: dict[str, Any] | None = None,
|
||||
):
|
||||
self.model = model
|
||||
self.token_encoder = token_encoder
|
||||
self.tokenizer = tokenizer
|
||||
self.rate_query = rate_query
|
||||
self.num_repeats = num_repeats
|
||||
self.use_summary = use_summary
|
||||
@ -97,7 +96,7 @@ class DynamicCommunitySelection:
|
||||
else self.reports[community].full_content
|
||||
),
|
||||
model=self.model,
|
||||
token_encoder=self.token_encoder,
|
||||
tokenizer=self.tokenizer,
|
||||
rate_query=self.rate_query,
|
||||
num_repeats=self.num_repeats,
|
||||
semaphore=self.semaphore,
|
||||
|
||||
@ -7,7 +7,6 @@ from collections import defaultdict
|
||||
from typing import Any, cast
|
||||
|
||||
import pandas as pd
|
||||
import tiktoken
|
||||
|
||||
from graphrag.data_model.covariate import Covariate
|
||||
from graphrag.data_model.entity import Entity
|
||||
@ -24,12 +23,13 @@ from graphrag.query.input.retrieval.relationships import (
|
||||
get_out_network_relationships,
|
||||
to_relationship_dataframe,
|
||||
)
|
||||
from graphrag.query.llm.text_utils import num_tokens
|
||||
from graphrag.tokenizer.get_tokenizer import get_tokenizer
|
||||
from graphrag.tokenizer.tokenizer import Tokenizer
|
||||
|
||||
|
||||
def build_entity_context(
|
||||
selected_entities: list[Entity],
|
||||
token_encoder: tiktoken.Encoding | None = None,
|
||||
tokenizer: Tokenizer | None = None,
|
||||
max_context_tokens: int = 8000,
|
||||
include_entity_rank: bool = True,
|
||||
rank_description: str = "number of relationships",
|
||||
@ -37,6 +37,8 @@ def build_entity_context(
|
||||
context_name="Entities",
|
||||
) -> tuple[str, pd.DataFrame]:
|
||||
"""Prepare entity data table as context data for system prompt."""
|
||||
tokenizer = tokenizer or get_tokenizer()
|
||||
|
||||
if len(selected_entities) == 0:
|
||||
return "", pd.DataFrame()
|
||||
|
||||
@ -52,7 +54,7 @@ def build_entity_context(
|
||||
)
|
||||
header.extend(attribute_cols)
|
||||
current_context_text += column_delimiter.join(header) + "\n"
|
||||
current_tokens = num_tokens(current_context_text, token_encoder)
|
||||
current_tokens = tokenizer.num_tokens(current_context_text)
|
||||
|
||||
all_context_records = [header]
|
||||
for entity in selected_entities:
|
||||
@ -71,7 +73,7 @@ def build_entity_context(
|
||||
)
|
||||
new_context.append(field_value)
|
||||
new_context_text = column_delimiter.join(new_context) + "\n"
|
||||
new_tokens = num_tokens(new_context_text, token_encoder)
|
||||
new_tokens = tokenizer.num_tokens(new_context_text)
|
||||
if current_tokens + new_tokens > max_context_tokens:
|
||||
break
|
||||
current_context_text += new_context_text
|
||||
@ -91,12 +93,13 @@ def build_entity_context(
|
||||
def build_covariates_context(
|
||||
selected_entities: list[Entity],
|
||||
covariates: list[Covariate],
|
||||
token_encoder: tiktoken.Encoding | None = None,
|
||||
tokenizer: Tokenizer | None = None,
|
||||
max_context_tokens: int = 8000,
|
||||
column_delimiter: str = "|",
|
||||
context_name: str = "Covariates",
|
||||
) -> tuple[str, pd.DataFrame]:
|
||||
"""Prepare covariate data tables as context data for system prompt."""
|
||||
tokenizer = tokenizer or get_tokenizer()
|
||||
# create an empty list of covariates
|
||||
if len(selected_entities) == 0 or len(covariates) == 0:
|
||||
return "", pd.DataFrame()
|
||||
@ -113,7 +116,7 @@ def build_covariates_context(
|
||||
attribute_cols = list(attributes.keys()) if len(covariates) > 0 else []
|
||||
header.extend(attribute_cols)
|
||||
current_context_text += column_delimiter.join(header) + "\n"
|
||||
current_tokens = num_tokens(current_context_text, token_encoder)
|
||||
current_tokens = tokenizer.num_tokens(current_context_text)
|
||||
|
||||
all_context_records = [header]
|
||||
for entity in selected_entities:
|
||||
@ -135,7 +138,7 @@ def build_covariates_context(
|
||||
new_context.append(field_value)
|
||||
|
||||
new_context_text = column_delimiter.join(new_context) + "\n"
|
||||
new_tokens = num_tokens(new_context_text, token_encoder)
|
||||
new_tokens = tokenizer.num_tokens(new_context_text)
|
||||
if current_tokens + new_tokens > max_context_tokens:
|
||||
break
|
||||
current_context_text += new_context_text
|
||||
@ -155,7 +158,7 @@ def build_covariates_context(
|
||||
def build_relationship_context(
|
||||
selected_entities: list[Entity],
|
||||
relationships: list[Relationship],
|
||||
token_encoder: tiktoken.Encoding | None = None,
|
||||
tokenizer: Tokenizer | None = None,
|
||||
include_relationship_weight: bool = False,
|
||||
max_context_tokens: int = 8000,
|
||||
top_k_relationships: int = 10,
|
||||
@ -164,6 +167,7 @@ def build_relationship_context(
|
||||
context_name: str = "Relationships",
|
||||
) -> tuple[str, pd.DataFrame]:
|
||||
"""Prepare relationship data tables as context data for system prompt."""
|
||||
tokenizer = tokenizer or get_tokenizer()
|
||||
selected_relationships = _filter_relationships(
|
||||
selected_entities=selected_entities,
|
||||
relationships=relationships,
|
||||
@ -188,7 +192,7 @@ def build_relationship_context(
|
||||
header.extend(attribute_cols)
|
||||
|
||||
current_context_text += column_delimiter.join(header) + "\n"
|
||||
current_tokens = num_tokens(current_context_text, token_encoder)
|
||||
current_tokens = tokenizer.num_tokens(current_context_text)
|
||||
|
||||
all_context_records = [header]
|
||||
for rel in selected_relationships:
|
||||
@ -208,7 +212,7 @@ def build_relationship_context(
|
||||
)
|
||||
new_context.append(field_value)
|
||||
new_context_text = column_delimiter.join(new_context) + "\n"
|
||||
new_tokens = num_tokens(new_context_text, token_encoder)
|
||||
new_tokens = tokenizer.num_tokens(new_context_text)
|
||||
if current_tokens + new_tokens > max_context_tokens:
|
||||
break
|
||||
current_context_text += new_context_text
|
||||
|
||||
@ -9,11 +9,11 @@ from contextlib import nullcontext
|
||||
from typing import Any
|
||||
|
||||
import numpy as np
|
||||
import tiktoken
|
||||
|
||||
from graphrag.language_model.protocol.base import ChatModel
|
||||
from graphrag.query.context_builder.rate_prompt import RATE_QUERY
|
||||
from graphrag.query.llm.text_utils import num_tokens, try_parse_json_object
|
||||
from graphrag.query.llm.text_utils import try_parse_json_object
|
||||
from graphrag.tokenizer.tokenizer import Tokenizer
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@ -22,7 +22,7 @@ async def rate_relevancy(
|
||||
query: str,
|
||||
description: str,
|
||||
model: ChatModel,
|
||||
token_encoder: tiktoken.Encoding,
|
||||
tokenizer: Tokenizer,
|
||||
rate_query: str = RATE_QUERY,
|
||||
num_repeats: int = 1,
|
||||
semaphore: asyncio.Semaphore | None = None,
|
||||
@ -36,7 +36,7 @@ async def rate_relevancy(
|
||||
description: the community description to rate, it can be the community
|
||||
title, summary, or the full content.
|
||||
llm: LLM model to use for rating
|
||||
token_encoder: token encoder
|
||||
tokenizer: tokenizer
|
||||
num_repeats: number of times to repeat the rating process for the same community (default: 1)
|
||||
model_params: additional arguments to pass to the LLM model
|
||||
semaphore: asyncio.Semaphore to limit the number of concurrent LLM calls (default: None)
|
||||
@ -63,8 +63,8 @@ async def rate_relevancy(
|
||||
logger.warning("Error parsing json response, defaulting to rating 1")
|
||||
ratings.append(1)
|
||||
llm_calls += 1
|
||||
prompt_tokens += num_tokens(messages[0]["content"], token_encoder)
|
||||
output_tokens += num_tokens(response, token_encoder)
|
||||
prompt_tokens += tokenizer.num_tokens(messages[0]["content"])
|
||||
output_tokens += tokenizer.num_tokens(response)
|
||||
# select the decision with the most votes
|
||||
options, counts = np.unique(ratings, return_counts=True)
|
||||
rating = int(options[np.argmax(counts)])
|
||||
|
||||
@ -7,11 +7,11 @@ import random
|
||||
from typing import Any, cast
|
||||
|
||||
import pandas as pd
|
||||
import tiktoken
|
||||
|
||||
from graphrag.data_model.relationship import Relationship
|
||||
from graphrag.data_model.text_unit import TextUnit
|
||||
from graphrag.query.llm.text_utils import num_tokens
|
||||
from graphrag.tokenizer.get_tokenizer import get_tokenizer
|
||||
from graphrag.tokenizer.tokenizer import Tokenizer
|
||||
|
||||
"""
|
||||
Contain util functions to build text unit context for the search's system prompt
|
||||
@ -20,7 +20,7 @@ Contain util functions to build text unit context for the search's system prompt
|
||||
|
||||
def build_text_unit_context(
|
||||
text_units: list[TextUnit],
|
||||
token_encoder: tiktoken.Encoding | None = None,
|
||||
tokenizer: Tokenizer | None = None,
|
||||
column_delimiter: str = "|",
|
||||
shuffle_data: bool = True,
|
||||
max_context_tokens: int = 8000,
|
||||
@ -28,6 +28,7 @@ def build_text_unit_context(
|
||||
random_state: int = 86,
|
||||
) -> tuple[str, dict[str, pd.DataFrame]]:
|
||||
"""Prepare text-unit data table as context data for system prompt."""
|
||||
tokenizer = tokenizer or get_tokenizer()
|
||||
if text_units is None or len(text_units) == 0:
|
||||
return ("", {})
|
||||
|
||||
@ -47,7 +48,7 @@ def build_text_unit_context(
|
||||
header.extend(attribute_cols)
|
||||
|
||||
current_context_text += column_delimiter.join(header) + "\n"
|
||||
current_tokens = num_tokens(current_context_text, token_encoder)
|
||||
current_tokens = tokenizer.num_tokens(current_context_text)
|
||||
all_context_records = [header]
|
||||
|
||||
for unit in text_units:
|
||||
@ -60,7 +61,7 @@ def build_text_unit_context(
|
||||
],
|
||||
]
|
||||
new_context_text = column_delimiter.join(new_context) + "\n"
|
||||
new_tokens = num_tokens(new_context_text, token_encoder)
|
||||
new_tokens = tokenizer.num_tokens(new_context_text)
|
||||
|
||||
if current_tokens + new_tokens > max_context_tokens:
|
||||
break
|
||||
|
||||
@ -3,8 +3,6 @@
|
||||
|
||||
"""Query Factory methods to support CLI."""
|
||||
|
||||
import tiktoken
|
||||
|
||||
from graphrag.callbacks.query_callbacks import QueryCallbacks
|
||||
from graphrag.config.models.graph_rag_config import GraphRagConfig
|
||||
from graphrag.data_model.community import Community
|
||||
@ -34,6 +32,7 @@ from graphrag.query.structured_search.local_search.mixed_context import (
|
||||
LocalSearchMixedContext,
|
||||
)
|
||||
from graphrag.query.structured_search.local_search.search import LocalSearch
|
||||
from graphrag.tokenizer.get_tokenizer import get_tokenizer
|
||||
from graphrag.vector_stores.base import BaseVectorStore
|
||||
|
||||
|
||||
@ -68,7 +67,7 @@ def get_local_search_engine(
|
||||
config=embedding_settings,
|
||||
)
|
||||
|
||||
token_encoder = tiktoken.get_encoding(model_settings.encoding_model)
|
||||
tokenizer = get_tokenizer(model_config=model_settings)
|
||||
|
||||
ls_config = config.local_search
|
||||
|
||||
@ -86,9 +85,9 @@ def get_local_search_engine(
|
||||
entity_text_embeddings=description_embedding_store,
|
||||
embedding_vectorstore_key=EntityVectorStoreKey.ID, # if the vectorstore uses entity title as ids, set this to EntityVectorStoreKey.TITLE
|
||||
text_embedder=embedding_model,
|
||||
token_encoder=token_encoder,
|
||||
tokenizer=tokenizer,
|
||||
),
|
||||
token_encoder=token_encoder,
|
||||
tokenizer=tokenizer,
|
||||
model_params=model_params,
|
||||
context_builder_params={
|
||||
"text_unit_prop": ls_config.text_unit_prop,
|
||||
@ -135,7 +134,7 @@ def get_global_search_engine(
|
||||
model_params = get_openai_model_parameters_from_config(model_settings)
|
||||
|
||||
# Here we get encoding based on specified encoding name
|
||||
token_encoder = tiktoken.get_encoding(model_settings.encoding_model)
|
||||
tokenizer = get_tokenizer(model_config=model_settings)
|
||||
gs_config = config.global_search
|
||||
|
||||
dynamic_community_selection_kwargs = {}
|
||||
@ -144,7 +143,7 @@ def get_global_search_engine(
|
||||
|
||||
dynamic_community_selection_kwargs.update({
|
||||
"model": model,
|
||||
"token_encoder": token_encoder,
|
||||
"tokenizer": tokenizer,
|
||||
"keep_parent": gs_config.dynamic_search_keep_parent,
|
||||
"num_repeats": gs_config.dynamic_search_num_repeats,
|
||||
"use_summary": gs_config.dynamic_search_use_summary,
|
||||
@ -163,11 +162,11 @@ def get_global_search_engine(
|
||||
community_reports=reports,
|
||||
communities=communities,
|
||||
entities=entities,
|
||||
token_encoder=token_encoder,
|
||||
tokenizer=tokenizer,
|
||||
dynamic_community_selection=dynamic_community_selection,
|
||||
dynamic_community_selection_kwargs=dynamic_community_selection_kwargs,
|
||||
),
|
||||
token_encoder=token_encoder,
|
||||
tokenizer=tokenizer,
|
||||
max_data_tokens=gs_config.data_max_tokens,
|
||||
map_llm_params={**model_params},
|
||||
reduce_llm_params={**model_params},
|
||||
@ -226,7 +225,7 @@ def get_drift_search_engine(
|
||||
config=embedding_model_settings,
|
||||
)
|
||||
|
||||
token_encoder = tiktoken.get_encoding(chat_model_settings.encoding_model)
|
||||
tokenizer = get_tokenizer(model_config=chat_model_settings)
|
||||
|
||||
return DRIFTSearch(
|
||||
model=chat_model,
|
||||
@ -243,7 +242,7 @@ def get_drift_search_engine(
|
||||
config=config.drift_search,
|
||||
response_type=response_type,
|
||||
),
|
||||
token_encoder=token_encoder,
|
||||
tokenizer=tokenizer,
|
||||
callbacks=callbacks,
|
||||
)
|
||||
|
||||
@ -277,7 +276,7 @@ def get_basic_search_engine(
|
||||
config=embedding_model_settings,
|
||||
)
|
||||
|
||||
token_encoder = tiktoken.get_encoding(chat_model_settings.encoding_model)
|
||||
tokenizer = get_tokenizer(model_config=chat_model_settings)
|
||||
|
||||
bs_config = config.basic_search
|
||||
|
||||
@ -291,9 +290,9 @@ def get_basic_search_engine(
|
||||
text_embedder=embedding_model,
|
||||
text_unit_embeddings=text_unit_embeddings,
|
||||
text_units=text_units,
|
||||
token_encoder=token_encoder,
|
||||
tokenizer=tokenizer,
|
||||
),
|
||||
token_encoder=token_encoder,
|
||||
tokenizer=tokenizer,
|
||||
model_params=model_params,
|
||||
context_builder_params={
|
||||
"embedding_vectorstore_key": "id",
|
||||
|
||||
@ -9,7 +9,6 @@ from dataclasses import dataclass
|
||||
from typing import Any, Generic, TypeVar
|
||||
|
||||
import pandas as pd
|
||||
import tiktoken
|
||||
|
||||
from graphrag.language_model.protocol.base import ChatModel
|
||||
from graphrag.query.context_builder.builders import (
|
||||
@ -21,6 +20,8 @@ from graphrag.query.context_builder.builders import (
|
||||
from graphrag.query.context_builder.conversation_history import (
|
||||
ConversationHistory,
|
||||
)
|
||||
from graphrag.tokenizer.get_tokenizer import get_tokenizer
|
||||
from graphrag.tokenizer.tokenizer import Tokenizer
|
||||
|
||||
|
||||
@dataclass
|
||||
@ -58,13 +59,13 @@ class BaseSearch(ABC, Generic[T]):
|
||||
self,
|
||||
model: ChatModel,
|
||||
context_builder: T,
|
||||
token_encoder: tiktoken.Encoding | None = None,
|
||||
tokenizer: Tokenizer | None = None,
|
||||
model_params: dict[str, Any] | None = None,
|
||||
context_builder_params: dict[str, Any] | None = None,
|
||||
):
|
||||
self.model = model
|
||||
self.context_builder = context_builder
|
||||
self.token_encoder = token_encoder
|
||||
self.tokenizer = tokenizer or get_tokenizer()
|
||||
self.model_params = model_params or {}
|
||||
self.context_builder_params = context_builder_params or {}
|
||||
|
||||
|
||||
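Note: a minimal usage sketch of the new optional tokenizer parameter introduced above. Only Tokenizer, get_tokenizer, and the `tokenizer or get_tokenizer()` fallback come from this change; the helper function below is purely illustrative.

from graphrag.tokenizer.get_tokenizer import get_tokenizer
from graphrag.tokenizer.tokenizer import Tokenizer


def count_prompt_tokens(prompt: str, tokenizer: Tokenizer | None = None) -> int:
    # Same fallback pattern as BaseSearch.__init__: use the supplied tokenizer,
    # otherwise fall back to the default tiktoken-based tokenizer.
    tokenizer = tokenizer or get_tokenizer()
    return len(tokenizer.encode(prompt))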
@ -7,7 +7,6 @@ import logging
|
||||
from typing import cast
|
||||
|
||||
import pandas as pd
|
||||
import tiktoken
|
||||
|
||||
from graphrag.data_model.text_unit import TextUnit
|
||||
from graphrag.language_model.protocol.base import EmbeddingModel
|
||||
@ -16,7 +15,8 @@ from graphrag.query.context_builder.builders import (
|
||||
ContextBuilderResult,
|
||||
)
|
||||
from graphrag.query.context_builder.conversation_history import ConversationHistory
|
||||
from graphrag.query.llm.text_utils import num_tokens
|
||||
from graphrag.tokenizer.get_tokenizer import get_tokenizer
|
||||
from graphrag.tokenizer.tokenizer import Tokenizer
|
||||
from graphrag.vector_stores.base import BaseVectorStore
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
@ -30,11 +30,11 @@ class BasicSearchContext(BasicContextBuilder):
|
||||
text_embedder: EmbeddingModel,
|
||||
text_unit_embeddings: BaseVectorStore,
|
||||
text_units: list[TextUnit] | None = None,
|
||||
token_encoder: tiktoken.Encoding | None = None,
|
||||
tokenizer: Tokenizer | None = None,
|
||||
embedding_vectorstore_key: str = "id",
|
||||
):
|
||||
self.text_embedder = text_embedder
|
||||
self.token_encoder = token_encoder
|
||||
self.tokenizer = tokenizer or get_tokenizer()
|
||||
self.text_units = text_units
|
||||
self.text_unit_embeddings = text_unit_embeddings
|
||||
self.embedding_vectorstore_key = embedding_vectorstore_key
|
||||
@ -76,12 +76,12 @@ class BasicSearchContext(BasicContextBuilder):
|
||||
# add these related text chunks into context until we fill up the context window
|
||||
current_tokens = 0
|
||||
text_ids = []
|
||||
current_tokens = num_tokens(
|
||||
text_id_col + column_delimiter + text_col + "\n", self.token_encoder
|
||||
current_tokens = len(
|
||||
self.tokenizer.encode(text_id_col + column_delimiter + text_col + "\n")
|
||||
)
|
||||
for i, row in related_text_df.iterrows():
|
||||
text = row[text_id_col] + column_delimiter + row[text_col] + "\n"
|
||||
tokens = num_tokens(text, self.token_encoder)
|
||||
tokens = len(self.tokenizer.encode(text))
|
||||
if current_tokens + tokens > max_context_tokens:
|
||||
msg = f"Reached token limit: {current_tokens + tokens}. Reverting to previous context state"
|
||||
logger.warning(msg)
|
||||
|
||||
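The same substitution repeats throughout the query package: num_tokens(text, token_encoder) becomes len(tokenizer.encode(text)), which is equivalent to the new Tokenizer.num_tokens helper. A small sketch of that equivalence (the sample text is arbitrary):

from graphrag.tokenizer.get_tokenizer import get_tokenizer

tokenizer = get_tokenizer()  # defaults to a tiktoken-based tokenizer
text = "GraphRAG query context"
assert len(tokenizer.encode(text)) == tokenizer.num_tokens(text)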
@ -8,8 +8,6 @@ import time
|
||||
from collections.abc import AsyncGenerator
|
||||
from typing import Any
|
||||
|
||||
import tiktoken
|
||||
|
||||
from graphrag.callbacks.query_callbacks import QueryCallbacks
|
||||
from graphrag.language_model.protocol.base import ChatModel
|
||||
from graphrag.prompts.query.basic_search_system_prompt import (
|
||||
@ -17,8 +15,8 @@ from graphrag.prompts.query.basic_search_system_prompt import (
|
||||
)
|
||||
from graphrag.query.context_builder.builders import BasicContextBuilder
|
||||
from graphrag.query.context_builder.conversation_history import ConversationHistory
|
||||
from graphrag.query.llm.text_utils import num_tokens
|
||||
from graphrag.query.structured_search.base import BaseSearch, SearchResult
|
||||
from graphrag.tokenizer.tokenizer import Tokenizer
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
"""
|
||||
@ -33,7 +31,7 @@ class BasicSearch(BaseSearch[BasicContextBuilder]):
|
||||
self,
|
||||
model: ChatModel,
|
||||
context_builder: BasicContextBuilder,
|
||||
token_encoder: tiktoken.Encoding | None = None,
|
||||
tokenizer: Tokenizer | None = None,
|
||||
system_prompt: str | None = None,
|
||||
response_type: str = "multiple paragraphs",
|
||||
callbacks: list[QueryCallbacks] | None = None,
|
||||
@ -43,7 +41,7 @@ class BasicSearch(BaseSearch[BasicContextBuilder]):
|
||||
super().__init__(
|
||||
model=model,
|
||||
context_builder=context_builder,
|
||||
token_encoder=token_encoder,
|
||||
tokenizer=tokenizer,
|
||||
model_params=model_params,
|
||||
context_builder_params=context_builder_params or {},
|
||||
)
|
||||
@ -94,8 +92,8 @@ class BasicSearch(BaseSearch[BasicContextBuilder]):
|
||||
response += chunk
|
||||
|
||||
llm_calls["response"] = 1
|
||||
prompt_tokens["response"] = num_tokens(search_prompt, self.token_encoder)
|
||||
output_tokens["response"] = num_tokens(response, self.token_encoder)
|
||||
prompt_tokens["response"] = len(self.tokenizer.encode(search_prompt))
|
||||
output_tokens["response"] = len(self.tokenizer.encode(response))
|
||||
|
||||
for callback in self.callbacks:
|
||||
callback.on_context(context_result.context_records)
|
||||
@ -106,7 +104,7 @@ class BasicSearch(BaseSearch[BasicContextBuilder]):
|
||||
context_text=context_result.context_chunks,
|
||||
completion_time=time.time() - start_time,
|
||||
llm_calls=1,
|
||||
prompt_tokens=num_tokens(search_prompt, self.token_encoder),
|
||||
prompt_tokens=len(self.tokenizer.encode(search_prompt)),
|
||||
output_tokens=sum(output_tokens.values()),
|
||||
llm_calls_categories=llm_calls,
|
||||
prompt_tokens_categories=prompt_tokens,
|
||||
@ -121,7 +119,7 @@ class BasicSearch(BaseSearch[BasicContextBuilder]):
|
||||
context_text=context_result.context_chunks,
|
||||
completion_time=time.time() - start_time,
|
||||
llm_calls=1,
|
||||
prompt_tokens=num_tokens(search_prompt, self.token_encoder),
|
||||
prompt_tokens=len(self.tokenizer.encode(search_prompt)),
|
||||
output_tokens=0,
|
||||
llm_calls_categories=llm_calls,
|
||||
prompt_tokens_categories=prompt_tokens,
|
||||
|
||||
@ -9,7 +9,6 @@ from typing import Any
|
||||
|
||||
import numpy as np
|
||||
import pandas as pd
|
||||
import tiktoken
|
||||
|
||||
from graphrag.config.models.drift_search_config import DRIFTSearchConfig
|
||||
from graphrag.data_model.community_report import CommunityReport
|
||||
@ -28,6 +27,8 @@ from graphrag.query.structured_search.drift_search.primer import PrimerQueryProc
|
||||
from graphrag.query.structured_search.local_search.mixed_context import (
|
||||
LocalSearchMixedContext,
|
||||
)
|
||||
from graphrag.tokenizer.get_tokenizer import get_tokenizer
|
||||
from graphrag.tokenizer.tokenizer import Tokenizer
|
||||
from graphrag.vector_stores.base import BaseVectorStore
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
@ -46,7 +47,7 @@ class DRIFTSearchContextBuilder(DRIFTContextBuilder):
|
||||
reports: list[CommunityReport] | None = None,
|
||||
relationships: list[Relationship] | None = None,
|
||||
covariates: dict[str, list[Covariate]] | None = None,
|
||||
token_encoder: tiktoken.Encoding | None = None,
|
||||
tokenizer: Tokenizer | None = None,
|
||||
embedding_vectorstore_key: str = EntityVectorStoreKey.ID,
|
||||
config: DRIFTSearchConfig | None = None,
|
||||
local_system_prompt: str | None = None,
|
||||
@ -58,7 +59,7 @@ class DRIFTSearchContextBuilder(DRIFTContextBuilder):
|
||||
self.config = config or DRIFTSearchConfig()
|
||||
self.model = model
|
||||
self.text_embedder = text_embedder
|
||||
self.token_encoder = token_encoder
|
||||
self.tokenizer = tokenizer or get_tokenizer()
|
||||
self.local_system_prompt = local_system_prompt or DRIFT_LOCAL_SYSTEM_PROMPT
|
||||
self.reduce_system_prompt = reduce_system_prompt or DRIFT_REDUCE_PROMPT
|
||||
|
||||
@ -93,7 +94,7 @@ class DRIFTSearchContextBuilder(DRIFTContextBuilder):
|
||||
entity_text_embeddings=self.entity_text_embeddings,
|
||||
embedding_vectorstore_key=self.embedding_vectorstore_key,
|
||||
text_embedder=self.text_embedder,
|
||||
token_encoder=self.token_encoder,
|
||||
tokenizer=self.tokenizer,
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
@ -192,7 +193,7 @@ class DRIFTSearchContextBuilder(DRIFTContextBuilder):
|
||||
query_processor = PrimerQueryProcessor(
|
||||
chat_model=self.model,
|
||||
text_embedder=self.text_embedder,
|
||||
token_encoder=self.token_encoder,
|
||||
tokenizer=self.tokenizer,
|
||||
reports=self.reports,
|
||||
)
|
||||
|
||||
|
||||
@ -10,7 +10,6 @@ import time
|
||||
|
||||
import numpy as np
|
||||
import pandas as pd
|
||||
import tiktoken
|
||||
from tqdm.asyncio import tqdm_asyncio
|
||||
|
||||
from graphrag.config.models.drift_search_config import DRIFTSearchConfig
|
||||
@ -19,8 +18,9 @@ from graphrag.language_model.protocol.base import ChatModel, EmbeddingModel
|
||||
from graphrag.prompts.query.drift_search_system_prompt import (
|
||||
DRIFT_PRIMER_PROMPT,
|
||||
)
|
||||
from graphrag.query.llm.text_utils import num_tokens
|
||||
from graphrag.query.structured_search.base import SearchResult
|
||||
from graphrag.tokenizer.get_tokenizer import get_tokenizer
|
||||
from graphrag.tokenizer.tokenizer import Tokenizer
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@ -33,7 +33,7 @@ class PrimerQueryProcessor:
|
||||
chat_model: ChatModel,
|
||||
text_embedder: EmbeddingModel,
|
||||
reports: list[CommunityReport],
|
||||
token_encoder: tiktoken.Encoding | None = None,
|
||||
tokenizer: Tokenizer | None = None,
|
||||
):
|
||||
"""
|
||||
Initialize the PrimerQueryProcessor.
|
||||
@ -42,11 +42,11 @@ class PrimerQueryProcessor:
|
||||
chat_llm (ChatOpenAI): The language model used to process the query.
|
||||
text_embedder (BaseTextEmbedding): The text embedding model.
|
||||
reports (list[CommunityReport]): List of community reports.
|
||||
token_encoder (tiktoken.Encoding, optional): Token encoder for token counting.
|
||||
tokenizer (Tokenizer, optional): Tokenizer used for token counting.
|
||||
"""
|
||||
self.chat_model = chat_model
|
||||
self.text_embedder = text_embedder
|
||||
self.token_encoder = token_encoder
|
||||
self.tokenizer = tokenizer or get_tokenizer()
|
||||
self.reports = reports
|
||||
|
||||
async def expand_query(self, query: str) -> tuple[str, dict[str, int]]:
|
||||
@ -70,8 +70,8 @@ class PrimerQueryProcessor:
|
||||
model_response = await self.chat_model.achat(prompt)
|
||||
text = model_response.output.content
|
||||
|
||||
prompt_tokens = num_tokens(prompt, self.token_encoder)
|
||||
output_tokens = num_tokens(text, self.token_encoder)
|
||||
prompt_tokens = len(self.tokenizer.encode(prompt))
|
||||
output_tokens = len(self.tokenizer.encode(text))
|
||||
token_ct = {
|
||||
"llm_calls": 1,
|
||||
"prompt_tokens": prompt_tokens,
|
||||
@ -105,7 +105,7 @@ class DRIFTPrimer:
|
||||
self,
|
||||
config: DRIFTSearchConfig,
|
||||
chat_model: ChatModel,
|
||||
token_encoder: tiktoken.Encoding | None = None,
|
||||
tokenizer: Tokenizer | None = None,
|
||||
):
|
||||
"""
|
||||
Initialize the DRIFTPrimer.
|
||||
@ -113,11 +113,11 @@ class DRIFTPrimer:
|
||||
Args:
|
||||
config (DRIFTSearchConfig): Configuration settings for DRIFT search.
|
||||
chat_llm (ChatOpenAI): The language model used for searching.
|
||||
token_encoder (tiktoken.Encoding, optional): Token encoder for managing tokens.
|
||||
tokenizer (Tokenizer, optional): Tokenizer for managing tokens.
|
||||
"""
|
||||
self.chat_model = chat_model
|
||||
self.config = config
|
||||
self.token_encoder = token_encoder
|
||||
self.tokenizer = tokenizer or get_tokenizer()
|
||||
|
||||
async def decompose_query(
|
||||
self, query: str, reports: pd.DataFrame
|
||||
@ -144,8 +144,8 @@ class DRIFTPrimer:
|
||||
|
||||
token_ct = {
|
||||
"llm_calls": 1,
|
||||
"prompt_tokens": num_tokens(prompt, self.token_encoder),
|
||||
"output_tokens": num_tokens(response, self.token_encoder),
|
||||
"prompt_tokens": len(self.tokenizer.encode(prompt)),
|
||||
"output_tokens": len(self.tokenizer.encode(response)),
|
||||
}
|
||||
|
||||
return parsed_response, token_ct
|
||||
|
||||
@ -8,7 +8,6 @@ import time
|
||||
from collections.abc import AsyncGenerator
|
||||
from typing import Any
|
||||
|
||||
import tiktoken
|
||||
from tqdm.asyncio import tqdm_asyncio
|
||||
|
||||
from graphrag.callbacks.query_callbacks import QueryCallbacks
|
||||
@ -18,7 +17,6 @@ from graphrag.language_model.providers.fnllm.utils import (
|
||||
)
|
||||
from graphrag.query.context_builder.conversation_history import ConversationHistory
|
||||
from graphrag.query.context_builder.entity_extraction import EntityVectorStoreKey
|
||||
from graphrag.query.llm.text_utils import num_tokens
|
||||
from graphrag.query.structured_search.base import BaseSearch, SearchResult
|
||||
from graphrag.query.structured_search.drift_search.action import DriftAction
|
||||
from graphrag.query.structured_search.drift_search.drift_context import (
|
||||
@ -27,6 +25,8 @@ from graphrag.query.structured_search.drift_search.drift_context import (
|
||||
from graphrag.query.structured_search.drift_search.primer import DRIFTPrimer
|
||||
from graphrag.query.structured_search.drift_search.state import QueryState
|
||||
from graphrag.query.structured_search.local_search.search import LocalSearch
|
||||
from graphrag.tokenizer.get_tokenizer import get_tokenizer
|
||||
from graphrag.tokenizer.tokenizer import Tokenizer
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@ -38,7 +38,7 @@ class DRIFTSearch(BaseSearch[DRIFTSearchContextBuilder]):
|
||||
self,
|
||||
model: ChatModel,
|
||||
context_builder: DRIFTSearchContextBuilder,
|
||||
token_encoder: tiktoken.Encoding | None = None,
|
||||
tokenizer: Tokenizer | None = None,
|
||||
query_state: QueryState | None = None,
|
||||
callbacks: list[QueryCallbacks] | None = None,
|
||||
):
|
||||
@ -49,18 +49,18 @@ class DRIFTSearch(BaseSearch[DRIFTSearchContextBuilder]):
|
||||
llm (ChatOpenAI): The language model used for searching.
|
||||
context_builder (DRIFTSearchContextBuilder): Builder for search context.
|
||||
config (DRIFTSearchConfig, optional): Configuration settings for DRIFTSearch.
|
||||
token_encoder (tiktoken.Encoding, optional): Token encoder for managing tokens.
|
||||
tokenizer (Tokenizer, optional): Tokenizer for managing tokens.
|
||||
query_state (QueryState, optional): State of the current search query.
|
||||
"""
|
||||
super().__init__(model, context_builder, token_encoder)
|
||||
super().__init__(model, context_builder, tokenizer)
|
||||
|
||||
self.context_builder = context_builder
|
||||
self.token_encoder = token_encoder
|
||||
self.tokenizer = tokenizer or get_tokenizer()
|
||||
self.query_state = query_state or QueryState()
|
||||
self.primer = DRIFTPrimer(
|
||||
config=self.context_builder.config,
|
||||
chat_model=model,
|
||||
token_encoder=token_encoder,
|
||||
tokenizer=self.tokenizer,
|
||||
)
|
||||
self.callbacks = callbacks or []
|
||||
self.local_search = self.init_local_search()
|
||||
@ -100,7 +100,7 @@ class DRIFTSearch(BaseSearch[DRIFTSearchContextBuilder]):
|
||||
model=self.model,
|
||||
system_prompt=self.context_builder.local_system_prompt,
|
||||
context_builder=self.context_builder.local_mixed_context,
|
||||
token_encoder=self.token_encoder,
|
||||
tokenizer=self.tokenizer,
|
||||
model_params=model_params,
|
||||
context_builder_params=local_context_params,
|
||||
response_type="multiple paragraphs",
|
||||
@ -392,10 +392,10 @@ class DRIFTSearch(BaseSearch[DRIFTSearchContextBuilder]):
|
||||
reduced_response = model_response.output.content
|
||||
|
||||
llm_calls["reduce"] = 1
|
||||
prompt_tokens["reduce"] = num_tokens(
|
||||
search_prompt, self.token_encoder
|
||||
) + num_tokens(query, self.token_encoder)
|
||||
output_tokens["reduce"] = num_tokens(reduced_response, self.token_encoder)
|
||||
prompt_tokens["reduce"] = len(self.tokenizer.encode(search_prompt)) + len(
|
||||
self.tokenizer.encode(query)
|
||||
)
|
||||
output_tokens["reduce"] = len(self.tokenizer.encode(reduced_response))
|
||||
|
||||
return reduced_response
|
||||
|
||||
|
||||
@ -5,8 +5,6 @@
|
||||
|
||||
from typing import Any
|
||||
|
||||
import tiktoken
|
||||
|
||||
from graphrag.data_model.community import Community
|
||||
from graphrag.data_model.community_report import CommunityReport
|
||||
from graphrag.data_model.entity import Entity
|
||||
@ -21,6 +19,8 @@ from graphrag.query.context_builder.dynamic_community_selection import (
|
||||
DynamicCommunitySelection,
|
||||
)
|
||||
from graphrag.query.structured_search.base import GlobalContextBuilder
|
||||
from graphrag.tokenizer.get_tokenizer import get_tokenizer
|
||||
from graphrag.tokenizer.tokenizer import Tokenizer
|
||||
|
||||
|
||||
class GlobalCommunityContext(GlobalContextBuilder):
|
||||
@ -31,14 +31,14 @@ class GlobalCommunityContext(GlobalContextBuilder):
|
||||
community_reports: list[CommunityReport],
|
||||
communities: list[Community],
|
||||
entities: list[Entity] | None = None,
|
||||
token_encoder: tiktoken.Encoding | None = None,
|
||||
tokenizer: Tokenizer | None = None,
|
||||
dynamic_community_selection: bool = False,
|
||||
dynamic_community_selection_kwargs: dict[str, Any] | None = None,
|
||||
random_state: int = 86,
|
||||
):
|
||||
self.community_reports = community_reports
|
||||
self.entities = entities
|
||||
self.token_encoder = token_encoder
|
||||
self.tokenizer = tokenizer or get_tokenizer()
|
||||
self.dynamic_community_selection = None
|
||||
if dynamic_community_selection and isinstance(
|
||||
dynamic_community_selection_kwargs, dict
|
||||
@ -47,7 +47,7 @@ class GlobalCommunityContext(GlobalContextBuilder):
|
||||
community_reports=community_reports,
|
||||
communities=communities,
|
||||
model=dynamic_community_selection_kwargs.pop("model"),
|
||||
token_encoder=dynamic_community_selection_kwargs.pop("token_encoder"),
|
||||
tokenizer=dynamic_community_selection_kwargs.pop("tokenizer"),
|
||||
**dynamic_community_selection_kwargs,
|
||||
)
|
||||
self.random_state = random_state
|
||||
@ -103,7 +103,7 @@ class GlobalCommunityContext(GlobalContextBuilder):
|
||||
community_context, community_context_data = build_community_context(
|
||||
community_reports=community_reports,
|
||||
entities=self.entities,
|
||||
token_encoder=self.token_encoder,
|
||||
tokenizer=self.tokenizer,
|
||||
use_community_summary=use_community_summary,
|
||||
column_delimiter=column_delimiter,
|
||||
shuffle_data=shuffle_data,
|
||||
|
||||
@ -12,7 +12,6 @@ from dataclasses import dataclass
|
||||
from typing import Any
|
||||
|
||||
import pandas as pd
|
||||
import tiktoken
|
||||
|
||||
from graphrag.callbacks.query_callbacks import QueryCallbacks
|
||||
from graphrag.language_model.protocol.base import ChatModel
|
||||
@ -30,8 +29,9 @@ from graphrag.query.context_builder.builders import GlobalContextBuilder
|
||||
from graphrag.query.context_builder.conversation_history import (
|
||||
ConversationHistory,
|
||||
)
|
||||
from graphrag.query.llm.text_utils import num_tokens, try_parse_json_object
|
||||
from graphrag.query.llm.text_utils import try_parse_json_object
|
||||
from graphrag.query.structured_search.base import BaseSearch, SearchResult
|
||||
from graphrag.tokenizer.tokenizer import Tokenizer
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@ -52,7 +52,7 @@ class GlobalSearch(BaseSearch[GlobalContextBuilder]):
|
||||
self,
|
||||
model: ChatModel,
|
||||
context_builder: GlobalContextBuilder,
|
||||
token_encoder: tiktoken.Encoding | None = None,
|
||||
tokenizer: Tokenizer | None = None,
|
||||
map_system_prompt: str | None = None,
|
||||
reduce_system_prompt: str | None = None,
|
||||
response_type: str = "multiple paragraphs",
|
||||
@ -71,7 +71,7 @@ class GlobalSearch(BaseSearch[GlobalContextBuilder]):
|
||||
super().__init__(
|
||||
model=model,
|
||||
context_builder=context_builder,
|
||||
token_encoder=token_encoder,
|
||||
tokenizer=tokenizer,
|
||||
context_builder_params=context_builder_params,
|
||||
)
|
||||
self.map_system_prompt = map_system_prompt or MAP_SYSTEM_PROMPT
|
||||
@ -247,8 +247,8 @@ class GlobalSearch(BaseSearch[GlobalContextBuilder]):
|
||||
context_text=context_data,
|
||||
completion_time=time.time() - start_time,
|
||||
llm_calls=1,
|
||||
prompt_tokens=num_tokens(search_prompt, self.token_encoder),
|
||||
output_tokens=num_tokens(search_response, self.token_encoder),
|
||||
prompt_tokens=len(self.tokenizer.encode(search_prompt)),
|
||||
output_tokens=len(self.tokenizer.encode(search_response)),
|
||||
)
|
||||
|
||||
except Exception:
|
||||
@ -259,7 +259,7 @@ class GlobalSearch(BaseSearch[GlobalContextBuilder]):
|
||||
context_text=context_data,
|
||||
completion_time=time.time() - start_time,
|
||||
llm_calls=1,
|
||||
prompt_tokens=num_tokens(search_prompt, self.token_encoder),
|
||||
prompt_tokens=len(self.tokenizer.encode(search_prompt)),
|
||||
output_tokens=0,
|
||||
)
|
||||
|
||||
@ -361,13 +361,12 @@ class GlobalSearch(BaseSearch[GlobalContextBuilder]):
|
||||
formatted_response_data.append(point["answer"]) # type: ignore
|
||||
formatted_response_text = "\n".join(formatted_response_data)
|
||||
if (
|
||||
total_tokens
|
||||
+ num_tokens(formatted_response_text, self.token_encoder)
|
||||
total_tokens + len(self.tokenizer.encode(formatted_response_text))
|
||||
> self.max_data_tokens
|
||||
):
|
||||
break
|
||||
data.append(formatted_response_text)
|
||||
total_tokens += num_tokens(formatted_response_text, self.token_encoder)
|
||||
total_tokens += len(self.tokenizer.encode(formatted_response_text))
|
||||
text_data = "\n\n".join(data)
|
||||
|
||||
search_prompt = self.reduce_system_prompt.format(
|
||||
@ -398,8 +397,8 @@ class GlobalSearch(BaseSearch[GlobalContextBuilder]):
|
||||
context_text=text_data,
|
||||
completion_time=time.time() - start_time,
|
||||
llm_calls=1,
|
||||
prompt_tokens=num_tokens(search_prompt, self.token_encoder),
|
||||
output_tokens=num_tokens(search_response, self.token_encoder),
|
||||
prompt_tokens=len(self.tokenizer.encode(search_prompt)),
|
||||
output_tokens=len(self.tokenizer.encode(search_response)),
|
||||
)
|
||||
except Exception:
|
||||
logger.exception("Exception in reduce_response")
|
||||
@ -409,7 +408,7 @@ class GlobalSearch(BaseSearch[GlobalContextBuilder]):
|
||||
context_text=text_data,
|
||||
completion_time=time.time() - start_time,
|
||||
llm_calls=1,
|
||||
prompt_tokens=num_tokens(search_prompt, self.token_encoder),
|
||||
prompt_tokens=len(self.tokenizer.encode(search_prompt)),
|
||||
output_tokens=0,
|
||||
)
|
||||
|
||||
@ -467,12 +466,12 @@ class GlobalSearch(BaseSearch[GlobalContextBuilder]):
|
||||
]
|
||||
formatted_response_text = "\n".join(formatted_response_data)
|
||||
if (
|
||||
total_tokens + num_tokens(formatted_response_text, self.token_encoder)
|
||||
total_tokens + len(self.tokenizer.encode(formatted_response_text))
|
||||
> self.max_data_tokens
|
||||
):
|
||||
break
|
||||
data.append(formatted_response_text)
|
||||
total_tokens += num_tokens(formatted_response_text, self.token_encoder)
|
||||
total_tokens += len(self.tokenizer.encode(formatted_response_text))
|
||||
text_data = "\n\n".join(data)
|
||||
|
||||
search_prompt = self.reduce_system_prompt.format(
|
||||
|
||||
@ -7,7 +7,6 @@ from copy import deepcopy
|
||||
from typing import Any
|
||||
|
||||
import pandas as pd
|
||||
import tiktoken
|
||||
|
||||
from graphrag.data_model.community_report import CommunityReport
|
||||
from graphrag.data_model.covariate import Covariate
|
||||
@ -40,8 +39,9 @@ from graphrag.query.input.retrieval.community_reports import (
|
||||
get_candidate_communities,
|
||||
)
|
||||
from graphrag.query.input.retrieval.text_units import get_candidate_text_units
|
||||
from graphrag.query.llm.text_utils import num_tokens
|
||||
from graphrag.query.structured_search.base import LocalContextBuilder
|
||||
from graphrag.tokenizer.get_tokenizer import get_tokenizer
|
||||
from graphrag.tokenizer.tokenizer import Tokenizer
|
||||
from graphrag.vector_stores.base import BaseVectorStore
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
@ -59,7 +59,7 @@ class LocalSearchMixedContext(LocalContextBuilder):
|
||||
community_reports: list[CommunityReport] | None = None,
|
||||
relationships: list[Relationship] | None = None,
|
||||
covariates: dict[str, list[Covariate]] | None = None,
|
||||
token_encoder: tiktoken.Encoding | None = None,
|
||||
tokenizer: Tokenizer | None = None,
|
||||
embedding_vectorstore_key: str = EntityVectorStoreKey.ID,
|
||||
):
|
||||
if community_reports is None:
|
||||
@ -81,7 +81,7 @@ class LocalSearchMixedContext(LocalContextBuilder):
|
||||
self.covariates = covariates
|
||||
self.entity_text_embeddings = entity_text_embeddings
|
||||
self.text_embedder = text_embedder
|
||||
self.token_encoder = token_encoder
|
||||
self.tokenizer = tokenizer or get_tokenizer()
|
||||
self.embedding_vectorstore_key = embedding_vectorstore_key
|
||||
|
||||
def filter_by_entity_keys(self, entity_keys: list[int] | list[str]):
|
||||
@ -167,8 +167,8 @@ class LocalSearchMixedContext(LocalContextBuilder):
|
||||
if conversation_history_context.strip() != "":
|
||||
final_context.append(conversation_history_context)
|
||||
final_context_data = conversation_history_context_data
|
||||
max_context_tokens = max_context_tokens - num_tokens(
|
||||
conversation_history_context, self.token_encoder
|
||||
max_context_tokens = max_context_tokens - len(
|
||||
self.tokenizer.encode(conversation_history_context)
|
||||
)
|
||||
|
||||
# build community context
|
||||
@ -264,7 +264,7 @@ class LocalSearchMixedContext(LocalContextBuilder):
|
||||
|
||||
context_text, context_data = build_community_context(
|
||||
community_reports=selected_communities,
|
||||
token_encoder=self.token_encoder,
|
||||
tokenizer=self.tokenizer,
|
||||
use_community_summary=use_community_summary,
|
||||
column_delimiter=column_delimiter,
|
||||
shuffle_data=False,
|
||||
@ -344,7 +344,7 @@ class LocalSearchMixedContext(LocalContextBuilder):
|
||||
|
||||
context_text, context_data = build_text_unit_context(
|
||||
text_units=selected_text_units,
|
||||
token_encoder=self.token_encoder,
|
||||
tokenizer=self.tokenizer,
|
||||
max_context_tokens=max_context_tokens,
|
||||
shuffle_data=False,
|
||||
context_name=context_name,
|
||||
@ -390,14 +390,14 @@ class LocalSearchMixedContext(LocalContextBuilder):
|
||||
# build entity context
|
||||
entity_context, entity_context_data = build_entity_context(
|
||||
selected_entities=selected_entities,
|
||||
token_encoder=self.token_encoder,
|
||||
tokenizer=self.tokenizer,
|
||||
max_context_tokens=max_context_tokens,
|
||||
column_delimiter=column_delimiter,
|
||||
include_entity_rank=include_entity_rank,
|
||||
rank_description=rank_description,
|
||||
context_name="Entities",
|
||||
)
|
||||
entity_tokens = num_tokens(entity_context, self.token_encoder)
|
||||
entity_tokens = len(self.tokenizer.encode(entity_context))
|
||||
|
||||
# build relationship-covariate context
|
||||
added_entities = []
|
||||
@ -417,7 +417,7 @@ class LocalSearchMixedContext(LocalContextBuilder):
|
||||
) = build_relationship_context(
|
||||
selected_entities=added_entities,
|
||||
relationships=list(self.relationships.values()),
|
||||
token_encoder=self.token_encoder,
|
||||
tokenizer=self.tokenizer,
|
||||
max_context_tokens=max_context_tokens,
|
||||
column_delimiter=column_delimiter,
|
||||
top_k_relationships=top_k_relationships,
|
||||
@ -427,8 +427,8 @@ class LocalSearchMixedContext(LocalContextBuilder):
|
||||
)
|
||||
current_context.append(relationship_context)
|
||||
current_context_data["relationships"] = relationship_context_data
|
||||
total_tokens = entity_tokens + num_tokens(
|
||||
relationship_context, self.token_encoder
|
||||
total_tokens = entity_tokens + len(
|
||||
self.tokenizer.encode(relationship_context)
|
||||
)
|
||||
|
||||
# build covariate context
|
||||
@ -436,12 +436,12 @@ class LocalSearchMixedContext(LocalContextBuilder):
|
||||
covariate_context, covariate_context_data = build_covariates_context(
|
||||
selected_entities=added_entities,
|
||||
covariates=self.covariates[covariate],
|
||||
token_encoder=self.token_encoder,
|
||||
tokenizer=self.tokenizer,
|
||||
max_context_tokens=max_context_tokens,
|
||||
column_delimiter=column_delimiter,
|
||||
context_name=covariate,
|
||||
)
|
||||
total_tokens += num_tokens(covariate_context, self.token_encoder)
|
||||
total_tokens += len(self.tokenizer.encode(covariate_context))
|
||||
current_context.append(covariate_context)
|
||||
current_context_data[covariate.lower()] = covariate_context_data
|
||||
|
||||
|
||||
@ -8,8 +8,6 @@ import time
|
||||
from collections.abc import AsyncGenerator
|
||||
from typing import Any
|
||||
|
||||
import tiktoken
|
||||
|
||||
from graphrag.callbacks.query_callbacks import QueryCallbacks
|
||||
from graphrag.language_model.protocol.base import ChatModel
|
||||
from graphrag.prompts.query.local_search_system_prompt import (
|
||||
@ -19,8 +17,8 @@ from graphrag.query.context_builder.builders import LocalContextBuilder
|
||||
from graphrag.query.context_builder.conversation_history import (
|
||||
ConversationHistory,
|
||||
)
|
||||
from graphrag.query.llm.text_utils import num_tokens
|
||||
from graphrag.query.structured_search.base import BaseSearch, SearchResult
|
||||
from graphrag.tokenizer.tokenizer import Tokenizer
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@ -32,7 +30,7 @@ class LocalSearch(BaseSearch[LocalContextBuilder]):
|
||||
self,
|
||||
model: ChatModel,
|
||||
context_builder: LocalContextBuilder,
|
||||
token_encoder: tiktoken.Encoding | None = None,
|
||||
tokenizer: Tokenizer | None = None,
|
||||
system_prompt: str | None = None,
|
||||
response_type: str = "multiple paragraphs",
|
||||
callbacks: list[QueryCallbacks] | None = None,
|
||||
@ -42,7 +40,7 @@ class LocalSearch(BaseSearch[LocalContextBuilder]):
|
||||
super().__init__(
|
||||
model=model,
|
||||
context_builder=context_builder,
|
||||
token_encoder=token_encoder,
|
||||
tokenizer=tokenizer,
|
||||
model_params=model_params,
|
||||
context_builder_params=context_builder_params or {},
|
||||
)
|
||||
@ -100,8 +98,8 @@ class LocalSearch(BaseSearch[LocalContextBuilder]):
|
||||
callback.on_llm_new_token(response)
|
||||
|
||||
llm_calls["response"] = 1
|
||||
prompt_tokens["response"] = num_tokens(search_prompt, self.token_encoder)
|
||||
output_tokens["response"] = num_tokens(full_response, self.token_encoder)
|
||||
prompt_tokens["response"] = len(self.tokenizer.encode(search_prompt))
|
||||
output_tokens["response"] = len(self.tokenizer.encode(full_response))
|
||||
|
||||
for callback in self.callbacks:
|
||||
callback.on_context(context_result.context_records)
|
||||
@ -127,7 +125,7 @@ class LocalSearch(BaseSearch[LocalContextBuilder]):
|
||||
context_text=context_result.context_chunks,
|
||||
completion_time=time.time() - start_time,
|
||||
llm_calls=1,
|
||||
prompt_tokens=num_tokens(search_prompt, self.token_encoder),
|
||||
prompt_tokens=len(self.tokenizer.encode(search_prompt)),
|
||||
output_tokens=0,
|
||||
)
|
||||
|
||||
|
||||
graphrag/tokenizer/__init__.py (new file, 4 lines)
@ -0,0 +1,4 @@
|
||||
# Copyright (c) 2024 Microsoft Corporation.
|
||||
# Licensed under the MIT License
|
||||
|
||||
"""GraphRAG tokenizer."""
|
||||
graphrag/tokenizer/get_tokenizer.py (new file, 45 lines)
@ -0,0 +1,45 @@
|
||||
# Copyright (c) 2024 Microsoft Corporation.
|
||||
# Licensed under the MIT License
|
||||
|
||||
"""Get Tokenizer."""
|
||||
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
from graphrag.config.defaults import ENCODING_MODEL
|
||||
from graphrag.tokenizer.litellm_tokenizer import LitellmTokenizer
|
||||
from graphrag.tokenizer.tiktoken_tokenizer import TiktokenTokenizer
|
||||
from graphrag.tokenizer.tokenizer import Tokenizer
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from graphrag.config.models.language_model_config import LanguageModelConfig
|
||||
|
||||
|
||||
def get_tokenizer(
|
||||
model_config: "LanguageModelConfig | None" = None,
|
||||
encoding_model: str = ENCODING_MODEL,
|
||||
) -> Tokenizer:
|
||||
"""
|
||||
Get the tokenizer for the given model configuration or fallback to a tiktoken based tokenizer.
|
||||
|
||||
Args
|
||||
----
|
||||
model_config: LanguageModelConfig, optional
|
||||
The model configuration. If not provided, or if model_config.encoding_model is explicitly set,
|
||||
a tiktoken-based tokenizer is used. Otherwise, a LitellmTokenizer based on the model name is used.
|
||||
LiteLLM supports token encoding/decoding for the range of models it supports.
|
||||
encoding_model: str, optional
|
||||
A tiktoken encoding model to use as a fallback when no model configuration is
|
||||
provided.
|
||||
|
||||
Returns
|
||||
-------
|
||||
An instance of a Tokenizer.
|
||||
"""
|
||||
if model_config is not None:
|
||||
if model_config.encoding_model.strip() != "":
|
||||
# User has manually specified a tiktoken encoding model to use for the provided model configuration.
|
||||
return TiktokenTokenizer(encoding_name=model_config.encoding_model)
|
||||
|
||||
return LitellmTokenizer(model_name=model_config.model)
|
||||
|
||||
return TiktokenTokenizer(encoding_name=encoding_model)
|
||||
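A short usage sketch of the selection logic above (the encoding name "cl100k_base" is a standard tiktoken encoding used here only as an example; constructing a full LanguageModelConfig is omitted because its required fields are defined elsewhere in the config package):

from graphrag.tokenizer.get_tokenizer import get_tokenizer
from graphrag.tokenizer.tiktoken_tokenizer import TiktokenTokenizer

# No model config: a tiktoken tokenizer for the default encoding model.
tokenizer = get_tokenizer()
assert isinstance(tokenizer, TiktokenTokenizer)

# Overriding the fallback encoding explicitly.
tokenizer_cl100k = get_tokenizer(encoding_model="cl100k_base")
assert tokenizer_cl100k.num_tokens("hello world") > 0

# With a LanguageModelConfig: a non-empty encoding_model yields
# TiktokenTokenizer(encoding_name=config.encoding_model); an empty
# encoding_model yields LitellmTokenizer(model_name=config.model).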
graphrag/tokenizer/litellm_tokenizer.py (new file, 47 lines)
@ -0,0 +1,47 @@
|
||||
# Copyright (c) 2024 Microsoft Corporation.
|
||||
# Licensed under the MIT License
|
||||
|
||||
"""LiteLLM Tokenizer."""
|
||||
|
||||
from litellm import decode, encode # type: ignore
|
||||
|
||||
from graphrag.tokenizer.tokenizer import Tokenizer
|
||||
|
||||
|
||||
class LitellmTokenizer(Tokenizer):
|
||||
"""LiteLLM Tokenizer."""
|
||||
|
||||
def __init__(self, model_name: str) -> None:
|
||||
"""Initialize the LiteLLM Tokenizer.
|
||||
|
||||
Args
|
||||
----
|
||||
model_name (str): The name of the LiteLLM model to use for tokenization.
|
||||
"""
|
||||
self.model_name = model_name
|
||||
|
||||
def encode(self, text: str) -> list[int]:
|
||||
"""Encode the given text into a list of tokens.
|
||||
|
||||
Args
|
||||
----
|
||||
text (str): The input text to encode.
|
||||
|
||||
Returns
|
||||
-------
|
||||
list[int]: A list of tokens representing the encoded text.
|
||||
"""
|
||||
return encode(model=self.model_name, text=text)
|
||||
|
||||
def decode(self, tokens: list[int]) -> str:
|
||||
"""Decode a list of tokens back into a string.
|
||||
|
||||
Args
|
||||
----
|
||||
tokens (list[int]): A list of tokens to decode.
|
||||
|
||||
Returns
|
||||
-------
|
||||
str: The decoded string from the list of tokens.
|
||||
"""
|
||||
return decode(model=self.model_name, tokens=tokens)
|
||||
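A brief usage sketch ("gpt-4o" is only an example; the calls resolve a tokenizer only for model names litellm recognizes):

from graphrag.tokenizer.litellm_tokenizer import LitellmTokenizer

tokenizer = LitellmTokenizer(model_name="gpt-4o")
tokens = tokenizer.encode("token counting via litellm")
text = tokenizer.decode(tokens)  # round-trips for plain text
count = tokenizer.num_tokens("token counting via litellm")  # equals len(tokens)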
graphrag/tokenizer/tiktoken_tokenizer.py (new file, 47 lines)
@ -0,0 +1,47 @@
|
||||
# Copyright (c) 2024 Microsoft Corporation.
|
||||
# Licensed under the MIT License
|
||||
|
||||
"""Tiktoken Tokenizer."""
|
||||
|
||||
import tiktoken
|
||||
|
||||
from graphrag.tokenizer.tokenizer import Tokenizer
|
||||
|
||||
|
||||
class TiktokenTokenizer(Tokenizer):
|
||||
"""Tiktoken Tokenizer."""
|
||||
|
||||
def __init__(self, encoding_name: str) -> None:
|
||||
"""Initialize the Tiktoken Tokenizer.
|
||||
|
||||
Args
|
||||
----
|
||||
encoding_name (str): The name of the Tiktoken encoding to use for tokenization.
|
||||
"""
|
||||
self.encoding = tiktoken.get_encoding(encoding_name)
|
||||
|
||||
def encode(self, text: str) -> list[int]:
|
||||
"""Encode the given text into a list of tokens.
|
||||
|
||||
Args
|
||||
----
|
||||
text (str): The input text to encode.
|
||||
|
||||
Returns
|
||||
-------
|
||||
list[int]: A list of tokens representing the encoded text.
|
||||
"""
|
||||
return self.encoding.encode(text)
|
||||
|
||||
def decode(self, tokens: list[int]) -> str:
|
||||
"""Decode a list of tokens back into a string.
|
||||
|
||||
Args
|
||||
----
|
||||
tokens (list[int]): A list of tokens to decode.
|
||||
|
||||
Returns
|
||||
-------
|
||||
str: The decoded string from the list of tokens.
|
||||
"""
|
||||
return self.encoding.decode(tokens)
|
||||
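A brief usage sketch ("cl100k_base" is a standard tiktoken encoding name, used here as an example):

from graphrag.tokenizer.tiktoken_tokenizer import TiktokenTokenizer

tokenizer = TiktokenTokenizer(encoding_name="cl100k_base")
tokens = tokenizer.encode("hello tokenizer")
assert tokenizer.decode(tokens) == "hello tokenizer"
assert tokenizer.num_tokens("hello tokenizer") == len(tokens)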
graphrag/tokenizer/tokenizer.py (new file, 53 lines)
@ -0,0 +1,53 @@
|
||||
# Copyright (c) 2024 Microsoft Corporation.
|
||||
# Licensed under the MIT License
|
||||
|
||||
"""Tokenizer Abstract Base Class."""
|
||||
|
||||
from abc import ABC, abstractmethod
|
||||
|
||||
|
||||
class Tokenizer(ABC):
|
||||
"""Tokenizer Abstract Base Class."""
|
||||
|
||||
@abstractmethod
|
||||
def encode(self, text: str) -> list[int]:
|
||||
"""Encode the given text into a list of tokens.
|
||||
|
||||
Args
|
||||
----
|
||||
text (str): The input text to encode.
|
||||
|
||||
Returns
|
||||
-------
|
||||
list[int]: A list of tokens representing the encoded text.
|
||||
"""
|
||||
msg = "The encode method must be implemented by subclasses."
|
||||
raise NotImplementedError(msg)
|
||||
|
||||
@abstractmethod
|
||||
def decode(self, tokens: list[int]) -> str:
|
||||
"""Decode a list of tokens back into a string.
|
||||
|
||||
Args
|
||||
----
|
||||
tokens (list[int]): A list of tokens to decode.
|
||||
|
||||
Returns
|
||||
-------
|
||||
str: The decoded string from the list of tokens.
|
||||
"""
|
||||
msg = "The decode method must be implemented by subclasses."
|
||||
raise NotImplementedError(msg)
|
||||
|
||||
def num_tokens(self, text: str) -> int:
|
||||
"""Return the number of tokens in the given text.
|
||||
|
||||
Args
|
||||
----
|
||||
text (str): The input text to analyze.
|
||||
|
||||
Returns
|
||||
-------
|
||||
int: The number of tokens in the input text.
|
||||
"""
|
||||
return len(self.encode(text))
|
||||
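For context, a toy (hypothetical) subclass showing the contract: implement encode and decode, and num_tokens is inherited from the base class.

from graphrag.tokenizer.tokenizer import Tokenizer


class WhitespaceTokenizer(Tokenizer):
    """Toy tokenizer that treats each whitespace-separated word as one token."""

    def __init__(self) -> None:
        self._vocab: dict[int, str] = {}

    def encode(self, text: str) -> list[int]:
        tokens: list[int] = []
        for word in text.split():
            token_id = hash(word)
            self._vocab[token_id] = word
            tokens.append(token_id)
        return tokens

    def decode(self, tokens: list[int]) -> str:
        return " ".join(self._vocab[token] for token in tokens)


assert WhitespaceTokenizer().num_tokens("three word sentence") == 3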
@ -43,7 +43,7 @@ dependencies = [
|
||||
"json-repair>=0.30.3",
|
||||
"openai>=1.68.0",
|
||||
"nltk==3.9.1",
|
||||
"tiktoken>=0.9.0",
|
||||
"tiktoken>=0.11.0",
|
||||
# Data-Science
|
||||
"numpy>=1.25.2",
|
||||
"graspologic>=3.4.1",
|
||||
@ -66,6 +66,7 @@ dependencies = [
|
||||
"tqdm>=4.67.1",
|
||||
"textblob>=0.18.0.post0",
|
||||
"spacy>=3.8.4",
|
||||
"litellm>=1.77.1",
|
||||
]
|
||||
|
||||
[project.optional-dependencies]
|
||||
|
||||
@ -9,7 +9,7 @@ import tiktoken
|
||||
|
||||
from graphrag.index.text_splitting.text_splitting import (
|
||||
NoopTextSplitter,
|
||||
Tokenizer,
|
||||
TokenChunkerOptions,
|
||||
TokenTextSplitter,
|
||||
split_multiple_texts_on_tokens,
|
||||
split_single_text_on_tokens,
|
||||
@ -64,7 +64,7 @@ def test_split_text_large_input(mock_split):
|
||||
|
||||
|
||||
@mock.patch("graphrag.index.text_splitting.text_splitting.split_single_text_on_tokens")
|
||||
@mock.patch("graphrag.index.text_splitting.text_splitting.Tokenizer")
|
||||
@mock.patch("graphrag.index.text_splitting.text_splitting.TokenChunkerOptions")
|
||||
def test_token_text_splitter(mock_tokenizer, mock_split_text):
|
||||
text = "chunk1 chunk2 chunk3"
|
||||
expected_chunks = ["chunk1", "chunk2", "chunk3"]
|
||||
@ -80,42 +80,10 @@ def test_token_text_splitter(mock_tokenizer, mock_split_text):
|
||||
mock_split_text.assert_called_once_with(text=text, tokenizer=mocked_tokenizer)
|
||||
|
||||
|
||||
def test_encode_basic():
|
||||
splitter = TokenTextSplitter()
|
||||
result = splitter.encode("abc def")
|
||||
|
||||
assert result == [13997, 711], "Encoding failed to return expected tokens"
|
||||
|
||||
|
||||
def test_num_tokens_empty_input():
|
||||
splitter = TokenTextSplitter()
|
||||
result = splitter.num_tokens("")
|
||||
|
||||
assert result == 0, "Token count for empty input should be 0"
|
||||
|
||||
|
||||
def test_model_name():
|
||||
splitter = TokenTextSplitter(model_name="gpt-4o")
|
||||
result = splitter.encode("abc def")
|
||||
|
||||
assert result == [26682, 1056], "Encoding failed to return expected tokens"
|
||||
|
||||
|
||||
@mock.patch("tiktoken.encoding_for_model", side_effect=KeyError)
|
||||
@mock.patch("tiktoken.get_encoding")
|
||||
def test_model_name_exception(mock_get_encoding, mock_encoding_for_model):
|
||||
mock_get_encoding.return_value = mock.MagicMock()
|
||||
|
||||
TokenTextSplitter(model_name="mock_model", encoding_name="mock_encoding")
|
||||
|
||||
mock_get_encoding.assert_called_once_with("mock_encoding")
|
||||
mock_encoding_for_model.assert_called_once_with("mock_model")
|
||||
|
||||
|
||||
def test_split_single_text_on_tokens():
|
||||
text = "This is a test text, meaning to be taken seriously by this test only."
|
||||
mocked_tokenizer = MockTokenizer()
|
||||
tokenizer = Tokenizer(
|
||||
tokenizer = TokenChunkerOptions(
|
||||
chunk_overlap=5,
|
||||
tokens_per_chunk=10,
|
||||
decode=mocked_tokenizer.decode,
|
||||
@ -150,7 +118,7 @@ def test_split_multiple_texts_on_tokens():
|
||||
|
||||
mocked_tokenizer = MockTokenizer()
|
||||
mock_tick = MagicMock()
|
||||
tokenizer = Tokenizer(
|
||||
tokenizer = TokenChunkerOptions(
|
||||
chunk_overlap=5,
|
||||
tokens_per_chunk=10,
|
||||
decode=mocked_tokenizer.decode,
|
||||
@ -173,7 +141,7 @@ def test_split_single_text_on_tokens_no_overlap():
|
||||
def decode(tokens: list[int]) -> str:
|
||||
return enc.decode(tokens)
|
||||
|
||||
tokenizer = Tokenizer(
|
||||
tokenizer = TokenChunkerOptions(
|
||||
chunk_overlap=1,
|
||||
tokens_per_chunk=2,
|
||||
decode=decode,
|
||||
|
||||
tests/unit/litellm_services/__init__.py (new file, 2 lines)
@ -0,0 +1,2 @@
|
||||
# Copyright (c) 2024 Microsoft Corporation.
|
||||
# Licensed under the MIT License
|
||||
tests/unit/litellm_services/test_rate_limiter.py (new file, 368 lines)
@ -0,0 +1,368 @@
|
||||
# Copyright (c) 2024 Microsoft Corporation.
|
||||
# Licensed under the MIT License
|
||||
|
||||
"""Test LiteLLM Rate Limiter."""
|
||||
|
||||
import threading
|
||||
import time
|
||||
from math import ceil
|
||||
from queue import Queue
|
||||
|
||||
import pytest
|
||||
|
||||
from graphrag.language_model.providers.litellm.services.rate_limiter.rate_limiter import (
|
||||
RateLimiter,
|
||||
)
|
||||
from graphrag.language_model.providers.litellm.services.rate_limiter.rate_limiter_factory import (
|
||||
RateLimiterFactory,
|
||||
)
|
||||
from tests.unit.litellm_services.utils import (
|
||||
assert_max_num_values_per_period,
|
||||
assert_stagger,
|
||||
bin_time_intervals,
|
||||
)
|
||||
|
||||
rate_limiter_factory = RateLimiterFactory()
|
||||
|
||||
_period_in_seconds = 1
|
||||
_rpm = 4
|
||||
_tpm = 75
|
||||
_tokens_per_request = 25
|
||||
_stagger = _period_in_seconds / _rpm
|
||||
_num_requests = 10
|
||||
|
||||
|
||||
def test_binning():
|
||||
"""Test binning timings into 1-second intervals."""
|
||||
values = [0.1, 0.2, 0.3, 0.4, 1.1, 1.2, 1.3, 1.4, 5.1]
|
||||
binned_values = bin_time_intervals(values, 1)
|
||||
assert binned_values == [
|
||||
[0.1, 0.2, 0.3, 0.4],
|
||||
[1.1, 1.2, 1.3, 1.4],
|
||||
[],
|
||||
[],
|
||||
[],
|
||||
[5.1],
|
||||
]
|
||||
|
||||
|
||||
def test_rate_limiter_validation():
|
||||
"""Test that the rate limiter can be created with valid parameters."""
|
||||
|
||||
# Valid parameters
|
||||
rate_limiter = rate_limiter_factory.create(
|
||||
strategy="static", rpm=60, tpm=10000, period_in_seconds=60
|
||||
)
|
||||
assert rate_limiter is not None
|
||||
|
||||
# Invalid strategy
|
||||
with pytest.raises(
|
||||
ValueError,
|
||||
match=r"Strategy 'invalid_strategy' is not registered.",
|
||||
):
|
||||
rate_limiter_factory.create(strategy="invalid_strategy", rpm=60, tpm=10000)
|
||||
|
||||
# Both rpm and tpm are None
|
||||
with pytest.raises(
|
||||
ValueError,
|
||||
match=r"Both TPM and RPM cannot be None \(disabled\), one or both must be set to a positive integer.",
|
||||
):
|
||||
rate_limiter_factory.create(strategy="static")
|
||||
|
||||
# Invalid rpm
|
||||
with pytest.raises(
|
||||
ValueError,
|
||||
match=r"RPM and TPM must be either None \(disabled\) or positive integers.",
|
||||
):
|
||||
rate_limiter_factory.create(strategy="static", rpm=-10)
|
||||
|
||||
# Invalid tpm
|
||||
with pytest.raises(
|
||||
ValueError,
|
||||
match=r"RPM and TPM must be either None \(disabled\) or positive integers.",
|
||||
):
|
||||
rate_limiter_factory.create(strategy="static", tpm=-10)
|
||||
|
||||
# Invalid period_in_seconds
|
||||
with pytest.raises(
|
||||
ValueError, match=r"Period in seconds must be a positive integer."
|
||||
):
|
||||
rate_limiter_factory.create(strategy="static", rpm=10, period_in_seconds=-10)
|
||||
|
||||
|
||||
def test_rpm():
|
||||
"""Test that the rate limiter enforces RPM limits."""
|
||||
rate_limiter = rate_limiter_factory.create(
|
||||
strategy="static", rpm=_rpm, period_in_seconds=_period_in_seconds
|
||||
)
|
||||
|
||||
time_values: list[float] = []
|
||||
start_time = time.time()
|
||||
for _ in range(_num_requests):
|
||||
with rate_limiter.acquire(token_count=_tokens_per_request):
|
||||
time_values.append(time.time() - start_time)
|
||||
|
||||
assert len(time_values) == _num_requests
|
||||
binned_time_values = bin_time_intervals(time_values, _period_in_seconds)
|
||||
|
||||
"""
|
||||
With _num_requests = 10 and _rpm = 4, we expect the requests to be
|
||||
distributed across ceil(10/4) = 3 bins:
|
||||
with a stagger of 1/4 = 0.25 seconds between requests.
|
||||
"""
|
||||
|
||||
expected_num_bins = ceil(_num_requests / _rpm)
|
||||
assert len(binned_time_values) == expected_num_bins
|
||||
|
||||
assert_max_num_values_per_period(binned_time_values, _rpm)
|
||||
assert_stagger(time_values, _stagger)
|
||||
|
||||
|
||||
def test_tpm():
|
||||
"""Test that the rate limiter enforces TPM limits."""
|
||||
rate_limiter = rate_limiter_factory.create(
|
||||
strategy="static", tpm=_tpm, period_in_seconds=_period_in_seconds
|
||||
)
|
||||
|
||||
time_values: list[float] = []
|
||||
start_time = time.time()
|
||||
for _ in range(_num_requests):
|
||||
with rate_limiter.acquire(token_count=_tokens_per_request):
|
||||
time_values.append(time.time() - start_time)
|
||||
|
||||
assert len(time_values) == _num_requests
|
||||
binned_time_values = bin_time_intervals(time_values, _period_in_seconds)
|
||||
|
||||
"""
|
||||
With _num_requests = 10, _tpm = 75 and _tokens_per_request = 25, we expect the requests to be
|
||||
distributed across ceil((10 * 25) / 75) = 4 bins
|
||||
and max requests per bin = (75 / 25) = 3 requests per bin.
|
||||
"""
|
||||
|
||||
expected_num_bins = ceil((_num_requests * _tokens_per_request) / _tpm)
|
||||
assert len(binned_time_values) == expected_num_bins
|
||||
|
||||
max_num_of_requests_per_bin = _tpm // _tokens_per_request
|
||||
assert_max_num_values_per_period(binned_time_values, max_num_of_requests_per_bin)
|
||||
|
||||
|
||||
def test_token_in_request_exceeds_tpm():
|
||||
"""Test that the rate limiter allows for requests that use more tokens than the TPM.
|
||||
|
||||
A rate limiter could be configured with a tpm of 1000 but a request may use 2000 tokens,
|
||||
greater than the tpm limit but still below the context window limit of the underlying model.
|
||||
In this case, the request should still be allowed to proceed but may take up its own rate limit bin.
|
||||
"""
|
||||
rate_limiter = rate_limiter_factory.create(
|
||||
strategy="static", tpm=_tpm, period_in_seconds=_period_in_seconds
|
||||
)
|
||||
|
||||
time_values: list[float] = []
|
||||
start_time = time.time()
|
||||
for _ in range(2):
|
||||
with rate_limiter.acquire(token_count=_tpm * 2):
|
||||
time_values.append(time.time() - start_time)
|
||||
|
||||
assert len(time_values) == 2
|
||||
binned_time_values = bin_time_intervals(time_values, _period_in_seconds)
|
||||
|
||||
"""
|
||||
Since each request exceeds the tpm, we expect each request to still be fired off but to be in its own bin
|
||||
"""
|
||||
|
||||
assert len(binned_time_values) == 2
|
||||
|
||||
assert_max_num_values_per_period(binned_time_values, 1)
|
||||
|
||||
|
||||
def test_rpm_and_tpm_with_rpm_as_limiting_factor():
|
||||
"""Test that the rate limiter enforces RPM and TPM limits."""
|
||||
rate_limiter = rate_limiter_factory.create(
|
||||
strategy="static", rpm=_rpm, tpm=_tpm, period_in_seconds=_period_in_seconds
|
||||
)
|
||||
|
||||
time_values: list[float] = []
|
||||
start_time = time.time()
|
||||
for _ in range(_num_requests):
|
||||
# Use 0 tokens per request to simulate RPM as the limiting factor
|
||||
with rate_limiter.acquire(token_count=0):
|
||||
time_values.append(time.time() - start_time)
|
||||
|
||||
assert len(time_values) == _num_requests
|
||||
binned_time_values = bin_time_intervals(time_values, _period_in_seconds)
|
||||
|
||||
"""
|
||||
With _num_requests = 10 and _rpm = 4, we expect the requests to be
|
||||
distributed across ceil(10/4) = 3 bins:
|
||||
with a stagger of 1/4 = 0.25 seconds between requests.
|
||||
"""
|
||||
|
||||
expected_num_bins = ceil(_num_requests / _rpm)
|
||||
assert len(binned_time_values) == expected_num_bins
|
||||
|
||||
assert_max_num_values_per_period(binned_time_values, _rpm)
|
||||
assert_stagger(time_values, _stagger)
|
||||
|
||||
|
||||
def test_rpm_and_tpm_with_tpm_as_limiting_factor():
|
||||
"""Test that the rate limiter enforces TPM limits."""
|
||||
rate_limiter = rate_limiter_factory.create(
|
||||
strategy="static", rpm=_rpm, tpm=_tpm, period_in_seconds=_period_in_seconds
|
||||
)
|
||||
|
||||
time_values: list[float] = []
|
||||
start_time = time.time()
|
||||
for _ in range(_num_requests):
|
||||
with rate_limiter.acquire(token_count=_tokens_per_request):
|
||||
time_values.append(time.time() - start_time)
|
||||
|
||||
assert len(time_values) == _num_requests
|
||||
binned_time_values = bin_time_intervals(time_values, _period_in_seconds)
|
||||
|
||||
"""
|
||||
With _num_requests = 10, _tpm = 75 and _tokens_per_request = 25, we expect the requests to be
|
||||
distributed across ceil((10 * 25) / 75) = 4 bins
|
||||
and max requests per bin = (75 / 25) = 3 requests per bin.
|
||||
"""
|
||||
|
||||
expected_num_bins = ceil((_num_requests * _tokens_per_request) / _tpm)
|
||||
assert len(binned_time_values) == expected_num_bins
|
||||
|
||||
max_num_of_requests_per_bin = _tpm // _tokens_per_request
|
||||
assert_max_num_values_per_period(binned_time_values, max_num_of_requests_per_bin)
|
||||
assert_stagger(time_values, _stagger)
|
||||
|
||||
|
||||
def _run_rate_limiter(
|
||||
rate_limiter: RateLimiter,
|
||||
# Acquire cost
|
||||
input_queue: Queue[int | None],
|
||||
# time value
|
||||
output_queue: Queue[float | None],
|
||||
):
|
||||
while True:
|
||||
token_count = input_queue.get()
|
||||
if token_count is None:
|
||||
break
|
||||
with rate_limiter.acquire(token_count=token_count):
|
||||
output_queue.put(time.time())
|
||||
|
||||
|
||||
def test_rpm_threaded():
|
||||
"""Test that the rate limiter enforces RPM limits in a threaded environment."""
|
||||
rate_limiter = rate_limiter_factory.create(
|
||||
strategy="static", rpm=_rpm, tpm=_tpm, period_in_seconds=_period_in_seconds
|
||||
)
|
||||
|
||||
input_queue: Queue[int | None] = Queue()
|
||||
output_queue: Queue[float | None] = Queue()
|
||||
|
||||
# Spin up threads for half the number of requests
|
||||
threads = [
|
||||
threading.Thread(
|
||||
target=_run_rate_limiter,
|
||||
args=(rate_limiter, input_queue, output_queue),
|
||||
)
|
||||
for _ in range(_num_requests // 2) # Create 5 threads
|
||||
]
|
||||
|
||||
for thread in threads:
|
||||
thread.start()
|
||||
|
||||
start_time = time.time()
|
||||
for _ in range(_num_requests):
|
||||
# Use 0 tokens per request to simulate RPM as the limiting factor
|
||||
input_queue.put(0)
|
||||
|
||||
# Signal threads to stop
|
||||
for _ in range(len(threads)):
|
||||
input_queue.put(None)
|
||||
|
||||
for thread in threads:
|
||||
thread.join()
|
||||
|
||||
output_queue.put(None) # Signal end of output
|
||||
|
||||
time_values = []
|
||||
while True:
|
||||
time_value = output_queue.get()
|
||||
if time_value is None:
|
||||
break
|
||||
time_values.append(time_value - start_time)
|
||||
|
||||
time_values.sort()
|
||||
|
||||
assert len(time_values) == _num_requests
|
||||
binned_time_values = bin_time_intervals(time_values, _period_in_seconds)
|
||||
|
||||
"""
|
||||
With _num_requests = 10 and _rpm = 4, we expect the requests to be
|
||||
distributed across ceil(10/4) = 3 bins:
|
||||
with a stagger of 1/4 = 0.25 seconds between requests.
|
||||
"""
|
||||
|
||||
expected_num_bins = ceil(_num_requests / _rpm)
|
||||
assert len(binned_time_values) == expected_num_bins
|
||||
|
||||
assert_max_num_values_per_period(binned_time_values, _rpm)
|
||||
assert_stagger(time_values, _stagger)
|
||||
|
||||
|
||||
def test_tpm_threaded():
|
||||
"""Test that the rate limiter enforces TPM limits in a threaded environment."""
|
||||
rate_limiter = rate_limiter_factory.create(
|
||||
strategy="static", rpm=_rpm, tpm=_tpm, period_in_seconds=_period_in_seconds
|
||||
)
|
||||
|
||||
input_queue: Queue[int | None] = Queue()
|
||||
output_queue: Queue[float | None] = Queue()
|
||||
|
||||
# Spin up threads for half the number of requests
|
||||
threads = [
|
||||
threading.Thread(
|
||||
target=_run_rate_limiter,
|
||||
args=(rate_limiter, input_queue, output_queue),
|
||||
)
|
||||
for _ in range(_num_requests // 2) # Create 5 threads
|
||||
]
|
||||
|
||||
for thread in threads:
|
||||
thread.start()
|
||||
|
||||
start_time = time.time()
|
||||
for _ in range(_num_requests):
|
||||
input_queue.put(_tokens_per_request)
|
||||
|
||||
# Signal threads to stop
|
||||
for _ in range(len(threads)):
|
||||
input_queue.put(None)
|
||||
|
||||
for thread in threads:
|
||||
thread.join()
|
||||
|
||||
output_queue.put(None) # Signal end of output
|
||||
|
||||
time_values = []
|
||||
while True:
|
||||
time_value = output_queue.get()
|
||||
if time_value is None:
|
||||
break
|
||||
time_values.append(time_value - start_time)
|
||||
|
||||
time_values.sort()
|
||||
|
||||
assert len(time_values) == _num_requests
|
||||
binned_time_values = bin_time_intervals(time_values, _period_in_seconds)
|
||||
|
||||
"""
|
||||
With _num_requests = 10, _tpm = 75 and _tokens_per_request = 25, we expect the requests to be
|
||||
distributed across ceil((10 * 25) / 75) = 4 bins
|
||||
and max requests per bin = (75 / 25) = 3 requests per bin.
|
||||
"""
|
||||
|
||||
expected_num_bins = ceil((_num_requests * _tokens_per_request) / _tpm)
|
||||
assert len(binned_time_values) == expected_num_bins
|
||||
|
||||
max_num_of_requests_per_bin = _tpm // _tokens_per_request
|
||||
assert_max_num_values_per_period(binned_time_values, max_num_of_requests_per_bin)
|
||||
assert_stagger(time_values, _stagger)
|
||||
tests/unit/litellm_services/test_retries.py (new file, 152 lines)
@ -0,0 +1,152 @@
# Copyright (c) 2024 Microsoft Corporation.
# Licensed under the MIT License

"""Test LiteLLM Retries."""

import time

import pytest

from graphrag.language_model.providers.litellm.services.retry.retry_factory import (
    RetryFactory,
)

retry_factory = RetryFactory()


@pytest.mark.parametrize(
    ("strategy", "max_attempts", "max_retry_wait", "expected_time"),
    [
        (
            "native",
            3,  # 3 retries
            0,  # native retry does not adhere to max_retry_wait
            0,  # immediate retry, expect 0 seconds elapsed time
        ),
        (
            "exponential_backoff",
            3,  # 3 retries
            0,  # exponential retry does not adhere to max_retry_wait
            14,  # (2^1 + jitter) + (2^2 + jitter) + (2^3 + jitter) >= 2 + 4 + 8 = 14 seconds minimum total runtime
        ),
        (
            "random_wait",
            3,  # 3 retries
            2,  # random wait in [0, 2] seconds
            0,  # unpredictable, so no lower bound on elapsed time
        ),
        (
            "incremental_wait",
            3,  # 3 retries
            3,  # wait for a max of 3 seconds on a single retry
            6,  # wait 3/3 * 1 on the first retry, 3/3 * 2 on the second, 3/3 * 3 on the third: 1 + 2 + 3 = 6 seconds total runtime
        ),
    ],
)
def test_retries(
    strategy: str, max_attempts: int, max_retry_wait: int, expected_time: float
) -> None:
    """
    Test various retry strategies with various configurations.

    Args
    ----
    strategy: The retry strategy to use.
    max_attempts: The maximum number of retry attempts.
    max_retry_wait: The maximum wait time between retries.
    expected_time: The minimum expected total elapsed time across all retries.
    """
    retry_service = retry_factory.create(
        strategy=strategy,
        max_attempts=max_attempts,
        max_retry_wait=max_retry_wait,
    )

    retries = 0

    def mock_func():
        nonlocal retries
        retries += 1
        msg = "Mock error for testing retries"
        raise ValueError(msg)

    start_time = time.time()
    with pytest.raises(ValueError, match="Mock error for testing retries"):
        retry_service.retry(func=mock_func)
    elapsed_time = time.time() - start_time

    # subtract 1 from retries because the first call is not a retry
    assert retries - 1 == max_attempts, (
        f"Expected {max_attempts} retries, got {retries}"
    )
    assert elapsed_time >= expected_time, (
        f"Expected elapsed time >= {expected_time}, got {elapsed_time}"
    )
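
# The timing expectations above, spelled out (a sketch of the wait math implied
# by the parametrize comments, not the providers' exact implementation):
#   exponential_backoff: waits 2^1, 2^2, 2^3 seconds (each plus a jitter >= 0),
#       so three retries take at least 2 + 4 + 8 = 14 seconds in total.
#   incremental_wait: with max_retry_wait = 3 and 3 attempts, waits
#       (3/3) * 1, (3/3) * 2, (3/3) * 3 seconds, i.e. 1 + 2 + 3 = 6 seconds in total.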


@pytest.mark.parametrize(
    ("strategy", "max_attempts", "max_retry_wait", "expected_time"),
    [
        (
            "native",
            3,  # 3 retries
            0,  # native retry does not adhere to max_retry_wait
            0,  # immediate retry, expect 0 seconds elapsed time
        ),
        (
            "exponential_backoff",
            3,  # 3 retries
            0,  # exponential retry does not adhere to max_retry_wait
            14,  # (2^1 + jitter) + (2^2 + jitter) + (2^3 + jitter) >= 2 + 4 + 8 = 14 seconds minimum total runtime
        ),
        (
            "random_wait",
            3,  # 3 retries
            2,  # random wait in [0, 2] seconds
            0,  # unpredictable, so no lower bound on elapsed time
        ),
        (
            "incremental_wait",
            3,  # 3 retries
            3,  # wait for a max of 3 seconds on a single retry
            6,  # wait 3/3 * 1 on the first retry, 3/3 * 2 on the second, 3/3 * 3 on the third: 1 + 2 + 3 = 6 seconds total runtime
        ),
    ],
)
async def test_retries_async(
    strategy: str, max_attempts: int, max_retry_wait: int, expected_time: float
) -> None:
    """
    Test various retry strategies with various configurations.

    Args
    ----
    strategy: The retry strategy to use.
    max_attempts: The maximum number of retry attempts.
    max_retry_wait: The maximum wait time between retries.
    expected_time: The minimum expected total elapsed time across all retries.
    """
    retry_service = retry_factory.create(
        strategy=strategy,
        max_attempts=max_attempts,
        max_retry_wait=max_retry_wait,
    )

    retries = 0

    async def mock_func():  # noqa: RUF029
        nonlocal retries
        retries += 1
        msg = "Mock error for testing retries"
        raise ValueError(msg)

    start_time = time.time()
    with pytest.raises(ValueError, match="Mock error for testing retries"):
        await retry_service.aretry(func=mock_func)
    elapsed_time = time.time() - start_time

    # subtract 1 from retries because the first call is not a retry
    assert retries - 1 == max_attempts, (
        f"Expected {max_attempts} retries, got {retries}"
    )
    assert elapsed_time >= expected_time, (
        f"Expected elapsed time >= {expected_time}, got {elapsed_time}"
    )
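The tests above only exercise a callable that always fails. For context, a minimal sketch of the success path, assuming the retry service stops retrying as soon as the wrapped callable returns without raising; the flaky() helper below is hypothetical and illustrative only:

calls = []


def flaky() -> None:
    """Hypothetical callable that fails twice, then succeeds."""
    calls.append(1)
    if len(calls) < 3:
        msg = "transient failure"
        raise ValueError(msg)


retry_service = retry_factory.create(
    strategy="incremental_wait", max_attempts=3, max_retry_wait=3
)
retry_service.retry(func=flaky)  # assumed to re-raise only if every attempt fails
assert len(calls) == 3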
tests/unit/litellm_services/utils.py (new file, 37 lines)
@@ -0,0 +1,37 @@
# Copyright (c) 2024 Microsoft Corporation.
# Licensed under the MIT License

"""LiteLLM Test Utilities."""


def bin_time_intervals(
    time_values: list[float], time_interval: int
) -> list[list[float]]:
    """Bin sorted (ascending) time values into consecutive intervals of time_interval seconds."""
    bins: list[list[float]] = []

    bin_number = 0
    for time_value in time_values:
        upper_bound = (bin_number * time_interval) + time_interval
        while time_value >= upper_bound:
            bin_number += 1
            upper_bound = (bin_number * time_interval) + time_interval
        while len(bins) <= bin_number:
            bins.append([])
        bins[bin_number].append(time_value)

    return bins
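
# Usage sketch (illustrative values only, not used by the tests): with a
# 1-second interval, timestamps are grouped by the second in which they fall,
# and empty bins are kept for seconds with no timestamps:
#
#     bin_time_intervals([0.1, 0.4, 1.2, 3.0], 1)
#     # -> [[0.1, 0.4], [1.2], [], [3.0]]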


def assert_max_num_values_per_period(
    periods: list[list[float]], max_values_per_period: int
):
    """Assert that no period contains more than max_values_per_period values."""
    for period in periods:
        assert len(period) <= max_values_per_period


def assert_stagger(time_values: list[float], stagger: float):
    """Assert that consecutive time values are at least stagger seconds apart."""
    for i in range(1, len(time_values)):
        assert time_values[i] - time_values[i - 1] >= stagger
tests/unit/utils/test_encoding.py (new file, 18 lines)
@@ -0,0 +1,18 @@
# Copyright (c) 2024 Microsoft Corporation.
# Licensed under the MIT License

from graphrag.tokenizer.get_tokenizer import get_tokenizer


def test_encode_basic():
    tokenizer = get_tokenizer()
    result = tokenizer.encode("abc def")

    assert result == [13997, 711], "Encoding failed to return expected tokens"


def test_num_tokens_empty_input():
    tokenizer = get_tokenizer()
    result = len(tokenizer.encode(""))

    assert result == 0, "Token count for empty input should be 0"
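A slightly broader usage sketch, under the assumption that the tokenizer returned by get_tokenizer() also exposes decode() and num_tokens() alongside encode(); exact token ids depend on the default encoding:

from graphrag.tokenizer.get_tokenizer import get_tokenizer

tokenizer = get_tokenizer()
tokens = tokenizer.encode("abc def")

assert tokenizer.decode(tokens) == "abc def"  # round trip (decode is assumed here)
assert tokenizer.num_tokens("abc def") == len(tokens)  # num_tokens is assumed here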