* Add LiteLLM chat and embedding model providers.

* Fix code review findings.

* Add litellm.

* Fix formatting.

* Update dictionary.

* Update litellm.

* Fix embedding.

* Remove manual use of tiktoken and replace it with the
Tokenizer interface. Add support for encoding
and decoding the models supported by litellm.

* Update litellm.

* Configure litellm to drop unsupported params.

* Clean up semversioner release notes.

* Add num_tokens util to Tokenizer interface.

* Update litellm service factories.

* Clean up litellm chat/embedding model argument assignment.

* Update chat and embedding type field for litellm use and future migration away from fnllm.

* Flatten litellm service organization.

* Update litellm.

* Update litellm factory validation.

* Flatten litellm rate limit service organization.

* Update rate limiter - disable with None/null instead of 0.

* Fix usage of get_tokenizer.

* Update litellm service registrations.

* Add jitter to exponential retry.

* Update validation.

* Update validation.

* Add litellm request logging layer.

* Update cache key.

* Update defaults.

---------

Co-authored-by: Alonso Guevara <alonsog@microsoft.com>
Derek Worthen 2025-09-22 12:55:14 -07:00 committed by GitHub
parent 82cd3b7df2
commit 2b70e4a4f3
66 changed files with 5269 additions and 1871 deletions

View File

@ -0,0 +1,4 @@
{
"type": "minor",
"description": "Add LiteLLM chat and embedding model providers."
}

View File

@ -81,6 +81,7 @@ typer
spacy
kwargs
ollama
litellm
# Library Methods
iterrows
@ -103,6 +104,8 @@ isin
nocache
nbconvert
levelno
acompletion
aembedding
# HTML
nbsp

View File

@ -47,6 +47,7 @@ from graphrag.prompt_tune.generator.language import detect_language
from graphrag.prompt_tune.generator.persona import generate_persona
from graphrag.prompt_tune.loader.input import load_docs_in_chunks
from graphrag.prompt_tune.types import DocSelectionType
from graphrag.tokenizer.get_tokenizer import get_tokenizer
logger = logging.getLogger(__name__)
@ -166,7 +167,7 @@ async def generate_indexing_prompts(
examples=examples,
language=language,
json_mode=False, # config.llm.model_supports_json should be used, but these prompts are used in non-json mode by the index engine
encoding_model=extract_graph_llm_settings.encoding_model,
tokenizer=get_tokenizer(model_config=extract_graph_llm_settings),
max_token_count=max_tokens,
min_examples_required=min_examples_required,
)

View File

@ -3,6 +3,7 @@
"""Common default configuration values."""
from collections.abc import Callable
from dataclasses import dataclass, field
from pathlib import Path
from typing import ClassVar, Literal
@ -23,6 +24,25 @@ from graphrag.config.enums import (
from graphrag.index.operations.build_noun_graph.np_extractors.stop_words import (
EN_STOP_WORDS,
)
from graphrag.language_model.providers.litellm.services.rate_limiter.rate_limiter import (
RateLimiter,
)
from graphrag.language_model.providers.litellm.services.rate_limiter.static_rate_limiter import (
StaticRateLimiter,
)
from graphrag.language_model.providers.litellm.services.retry.exponential_retry import (
ExponentialRetry,
)
from graphrag.language_model.providers.litellm.services.retry.incremental_wait_retry import (
IncrementalWaitRetry,
)
from graphrag.language_model.providers.litellm.services.retry.native_wait_retry import (
NativeRetry,
)
from graphrag.language_model.providers.litellm.services.retry.random_wait_retry import (
RandomWaitRetry,
)
from graphrag.language_model.providers.litellm.services.retry.retry import Retry
DEFAULT_OUTPUT_BASE_DIR = "output"
DEFAULT_CHAT_MODEL_ID = "default_chat_model"
@ -39,6 +59,18 @@ ENCODING_MODEL = "cl100k_base"
COGNITIVE_SERVICES_AUDIENCE = "https://cognitiveservices.azure.com/.default"
DEFAULT_RETRY_SERVICES: dict[str, Callable[..., Retry]] = {
"native": NativeRetry,
"exponential_backoff": ExponentialRetry,
"random_wait": RandomWaitRetry,
"incremental_wait": IncrementalWaitRetry,
}
DEFAULT_RATE_LIMITER_SERVICES: dict[str, Callable[..., RateLimiter]] = {
"static": StaticRateLimiter,
}
@dataclass
class BasicSearchDefaults:
"""Default values for basic search."""
@ -275,6 +307,7 @@ class LanguageModelDefaults:
api_key: None = None
auth_type: ClassVar[AuthType] = AuthType.APIKey
model_provider: str | None = None
encoding_model: str = ""
max_tokens: int | None = None
temperature: float = 0
@ -294,6 +327,7 @@ class LanguageModelDefaults:
model_supports_json: None = None
tokens_per_minute: Literal["auto"] = "auto"
requests_per_minute: Literal["auto"] = "auto"
rate_limit_strategy: str | None = "static"
retry_strategy: str = "native"
max_retries: int = 10
max_retry_wait: float = 10.0

View File

@ -86,10 +86,12 @@ class ModelType(str, Enum):
# Embeddings
OpenAIEmbedding = "openai_embedding"
AzureOpenAIEmbedding = "azure_openai_embedding"
Embedding = "embedding"
# Chat Completion
OpenAIChat = "openai_chat"
AzureOpenAIChat = "azure_openai_chat"
Chat = "chat"
# Debug
MockChat = "mock_chat"

View File

@ -37,6 +37,12 @@ from graphrag.config.models.summarize_descriptions_config import (
from graphrag.config.models.text_embedding_config import TextEmbeddingConfig
from graphrag.config.models.umap_config import UmapConfig
from graphrag.config.models.vector_store_config import VectorStoreConfig
from graphrag.language_model.providers.litellm.services.rate_limiter.rate_limiter_factory import (
RateLimiterFactory,
)
from graphrag.language_model.providers.litellm.services.retry.retry_factory import (
RetryFactory,
)
class GraphRagConfig(BaseModel):
@ -89,6 +95,47 @@ class GraphRagConfig(BaseModel):
if defs.DEFAULT_EMBEDDING_MODEL_ID not in self.models:
raise LanguageModelConfigMissingError(defs.DEFAULT_EMBEDDING_MODEL_ID)
def _validate_retry_services(self) -> None:
"""Validate the retry services configuration."""
retry_factory = RetryFactory()
for model_id, model in self.models.items():
if model.retry_strategy != "none":
if model.retry_strategy not in retry_factory:
msg = f"Retry strategy '{model.retry_strategy}' for model '{model_id}' is not registered. Available strategies: {', '.join(retry_factory.keys())}"
raise ValueError(msg)
_ = retry_factory.create(
strategy=model.retry_strategy,
max_attempts=model.max_retries,
max_retry_wait=model.max_retry_wait,
)
def _validate_rate_limiter_services(self) -> None:
"""Validate the rate limiter services configuration."""
rate_limiter_factory = RateLimiterFactory()
for model_id, model in self.models.items():
if model.rate_limit_strategy is not None:
if model.rate_limit_strategy not in rate_limiter_factory:
msg = f"Rate Limiter strategy '{model.rate_limit_strategy}' for model '{model_id}' is not registered. Available strategies: {', '.join(rate_limiter_factory.keys())}"
raise ValueError(msg)
rpm = (
model.requests_per_minute
if type(model.requests_per_minute) is int
else None
)
tpm = (
model.tokens_per_minute
if type(model.tokens_per_minute) is int
else None
)
if rpm is not None or tpm is not None:
_ = rate_limiter_factory.create(
strategy=model.rate_limit_strategy, rpm=rpm, tpm=tpm
)
input: InputConfig = Field(
description="The input configuration.", default=InputConfig()
)
@ -300,6 +347,11 @@ class GraphRagConfig(BaseModel):
raise ValueError(msg)
store.db_uri = str((Path(self.root_dir) / store.db_uri).resolve())
def _validate_factories(self) -> None:
"""Validate the factories used in the configuration."""
self._validate_retry_services()
self._validate_rate_limiter_services()
def get_language_model_config(self, model_id: str) -> LanguageModelConfig:
"""Get a model configuration by ID.
@ -360,4 +412,5 @@ class GraphRagConfig(BaseModel):
self._validate_multi_output_base_dirs()
self._validate_update_index_output_base_dir()
self._validate_vector_store_db_uri()
self._validate_factories()
return self

View File

@ -73,8 +73,11 @@ class LanguageModelConfig(BaseModel):
ConflictingSettingsError
If the Azure authentication type conflicts with the model being used.
"""
if self.auth_type == AuthType.AzureManagedIdentity and (
self.type == ModelType.OpenAIChat or self.type == ModelType.OpenAIEmbedding
if (
self.auth_type == AuthType.AzureManagedIdentity
and self.type != ModelType.AzureOpenAIChat
and self.type != ModelType.AzureOpenAIEmbedding
and self.model_provider != "azure" # indicates Litellm + AOI
):
msg = f"auth_type of azure_managed_identity is not supported for model type {self.type}. Please rerun `graphrag init` and set the auth_type to api_key."
raise ConflictingSettingsError(msg)
@ -94,6 +97,27 @@ class LanguageModelConfig(BaseModel):
msg = f"Model type {self.type} is not recognized, must be one of {ModelFactory.get_chat_models() + ModelFactory.get_embedding_models()}."
raise KeyError(msg)
model_provider: str | None = Field(
description="The model provider to use.",
default=language_model_defaults.model_provider,
)
def _validate_model_provider(self) -> None:
"""Validate the model provider.
Required when using Litellm.
Raises
------
KeyError
If the model provider is not recognized.
"""
if (self.type == ModelType.Chat or self.type == ModelType.Embedding) and (
self.model_provider is None or self.model_provider.strip() == ""
):
msg = f"Model provider must be specified when using type == {self.type}."
raise KeyError(msg)
model: str = Field(description="The LLM model to use.")
encoding_model: str = Field(
description="The encoding model to use",
@ -103,12 +127,27 @@ class LanguageModelConfig(BaseModel):
def _validate_encoding_model(self) -> None:
"""Validate the encoding model.
The default behavior is to use an encoding model that matches the LLM model.
LiteLLM supports 100+ models and their tokenization, so there is no need to
set the encoding model when using the new LiteLLM provider, as was required with the fnllm provider.
Users can still manually specify a tiktoken-based encoding model with the LiteLLM provider,
in which case the specified encoding model is used regardless of the LLM model selected,
even if it is not an OpenAI-based model.
If not using the LiteLLM provider, set the encoding model based on the LLM model name.
This preserves backward compatibility with the existing fnllm provider until fnllm is removed.
Raises
------
KeyError
If the model name is not recognized.
"""
if self.encoding_model.strip() == "":
if (
self.type != ModelType.Chat
and self.type != ModelType.Embedding
and self.encoding_model.strip() == ""
):
self.encoding_model = tiktoken.encoding_name_for_model(self.model)
api_base: str | None = Field(
@ -129,6 +168,7 @@ class LanguageModelConfig(BaseModel):
if (
self.type == ModelType.AzureOpenAIChat
or self.type == ModelType.AzureOpenAIEmbedding
or self.model_provider == "azure" # indicates Litellm + AOI
) and (self.api_base is None or self.api_base.strip() == ""):
raise AzureApiBaseMissingError(self.type)
@ -150,6 +190,7 @@ class LanguageModelConfig(BaseModel):
if (
self.type == ModelType.AzureOpenAIChat
or self.type == ModelType.AzureOpenAIEmbedding
or self.model_provider == "azure" # indicates Litellm + AOI
) and (self.api_version is None or self.api_version.strip() == ""):
raise AzureApiVersionMissingError(self.type)
@ -171,6 +212,7 @@ class LanguageModelConfig(BaseModel):
if (
self.type == ModelType.AzureOpenAIChat
or self.type == ModelType.AzureOpenAIEmbedding
or self.model_provider == "azure" # indicates Litellm + AOI
) and (self.deployment_name is None or self.deployment_name.strip() == ""):
raise AzureDeploymentNameMissingError(self.type)
@ -212,6 +254,14 @@ class LanguageModelConfig(BaseModel):
msg = f"Tokens per minute must be a non zero positive number, 'auto' or null. Suggested value: {language_model_defaults.tokens_per_minute}."
raise ValueError(msg)
if (
(self.type == ModelType.Chat or self.type == ModelType.Embedding)
and self.rate_limit_strategy is not None
and self.tokens_per_minute == "auto"
):
msg = f"tokens_per_minute cannot be set to 'auto' when using type '{self.type}'. Please set it to a positive integer or null to disable."
raise ValueError(msg)
requests_per_minute: int | Literal["auto"] | None = Field(
description="The number of requests per minute to use for the LLM service.",
default=language_model_defaults.requests_per_minute,
@ -230,6 +280,19 @@ class LanguageModelConfig(BaseModel):
msg = f"Requests per minute must be a non zero positive number, 'auto' or null. Suggested value: {language_model_defaults.requests_per_minute}."
raise ValueError(msg)
if (
(self.type == ModelType.Chat or self.type == ModelType.Embedding)
and self.rate_limit_strategy is not None
and self.requests_per_minute == "auto"
):
msg = f"requests_per_minute cannot be set to 'auto' when using type '{self.type}'. Please set it to a positive integer or null to disable."
raise ValueError(msg)
rate_limit_strategy: str | None = Field(
description="The rate limit strategy to use for the LLM service.",
default=language_model_defaults.rate_limit_strategy,
)
retry_strategy: str = Field(
description="The retry strategy to use for the LLM service.",
default=language_model_defaults.retry_strategy,
@ -318,6 +381,7 @@ class LanguageModelConfig(BaseModel):
@model_validator(mode="after")
def _validate_model(self):
self._validate_type()
self._validate_model_provider()
self._validate_auth_type()
self._validate_api_key()
self._validate_tokens_per_minute()
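Taken together with the defaults earlier in the diff, a model entry for the new LiteLLM-backed types might look like the following. This is a minimal sketch as a plain Python dict of config fields, not a snippet from the commit; the model name and key are placeholders. Note that tokens_per_minute/requests_per_minute must be integers or None here, since "auto" is rejected for type "chat"/"embedding" when a rate limit strategy is set.

# Illustrative only: a model entry exercising the new LiteLLM fields (placeholder values).
litellm_chat_model = {
    "type": "chat",                           # new LiteLLM-backed type, vs. the fnllm "openai_chat"
    "model_provider": "openai",               # required when type is "chat" or "embedding"
    "model": "gpt-4o",                        # placeholder model name
    "api_key": "<API_KEY>",                   # placeholder
    "rate_limit_strategy": "static",          # default; None/null disables rate limiting
    "tokens_per_minute": 50_000,              # int or None; "auto" raises for type "chat"
    "requests_per_minute": 500,               # int or None; "auto" raises for type "chat"
    "retry_strategy": "exponential_backoff",  # one of: native, exponential_backoff, random_wait, incremental_wait
    "max_retries": 10,
    "max_retry_wait": 10.0,
}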

View File

@ -0,0 +1,4 @@
# Copyright (c) 2025 Microsoft Corporation.
# Licensed under the MIT License
"""Factory module."""

View File

@ -0,0 +1,68 @@
# Copyright (c) 2025 Microsoft Corporation.
# Licensed under the MIT License
"""Factory ABC."""
from abc import ABC
from collections.abc import Callable
from typing import Any, ClassVar, Generic, TypeVar
T = TypeVar("T", covariant=True)
class Factory(ABC, Generic[T]):
"""Abstract base class for factories."""
_instance: ClassVar["Factory | None"] = None
def __new__(cls, *args: Any, **kwargs: Any) -> "Factory":
"""Create a new instance of Factory if it does not exist."""
if cls._instance is None:
cls._instance = super().__new__(cls, *args, **kwargs)
return cls._instance
def __init__(self):
if not hasattr(self, "_initialized"):
self._services: dict[str, Callable[..., T]] = {}
self._initialized = True
def __contains__(self, strategy: str) -> bool:
"""Check if a strategy is registered."""
return strategy in self._services
def keys(self) -> list[str]:
"""Get a list of registered strategy names."""
return list(self._services.keys())
def register(self, *, strategy: str, service_initializer: Callable[..., T]) -> None:
"""
Register a new service.
Args
----
strategy: The name of the strategy.
service_initializer: A callable that creates an instance of T.
"""
self._services[strategy] = service_initializer
def create(self, *, strategy: str, **kwargs: Any) -> T:
"""
Create a service instance based on the strategy.
Args
----
strategy: The name of the strategy.
**kwargs: Additional arguments to pass to the service initializer.
Returns
-------
An instance of T.
Raises
------
ValueError: If the strategy is not registered.
"""
if strategy not in self._services:
msg = f"Strategy '{strategy}' is not registered."
raise ValueError(msg)
return self._services[strategy](**kwargs)
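As a usage sketch (hypothetical Greeter service, not part of the commit), the retry and rate-limiter factories below are driven the same way:

from graphrag.factory.factory import Factory


class Greeter:
    """Trivial service used only to illustrate the Factory API."""

    def __init__(self, *, name: str = "world") -> None:
        self.name = name

    def greet(self) -> str:
        return f"hello, {self.name}"


class GreeterFactory(Factory[Greeter]):
    """Singleton factory for Greeter services (illustrative only)."""


factory = GreeterFactory()
factory.register(strategy="default", service_initializer=Greeter)

assert "default" in factory           # __contains__
assert factory.keys() == ["default"]  # registered strategy names
greeter = factory.create(strategy="default", name="graphrag")
assert greeter.greet() == "hello, graphrag"
# factory.create(strategy="missing") would raise ValueError: strategy not registered.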

View File

@ -11,7 +11,7 @@ import tiktoken
from graphrag.config.models.chunking_config import ChunkingConfig
from graphrag.index.operations.chunk_text.typing import TextChunk
from graphrag.index.text_splitting.text_splitting import (
Tokenizer,
TokenChunkerOptions,
split_multiple_texts_on_tokens,
)
from graphrag.logger.progress import ProgressTicker
@ -45,7 +45,7 @@ def run_tokens(
encode, decode = get_encoding_fn(encoding_name)
return split_multiple_texts_on_tokens(
input,
Tokenizer(
TokenChunkerOptions(
chunk_overlap=chunk_overlap,
tokens_per_chunk=tokens_per_chunk,
encode=encode,

View File

@ -18,6 +18,7 @@ from graphrag.index.utils.is_null import is_null
from graphrag.language_model.manager import ModelManager
from graphrag.language_model.protocol.base import EmbeddingModel
from graphrag.logger.progress import ProgressTicker, progress_ticker
from graphrag.tokenizer.get_tokenizer import get_tokenizer
logger = logging.getLogger(__name__)
@ -79,7 +80,7 @@ def _get_splitter(
config: LanguageModelConfig, batch_max_tokens: int
) -> TokenTextSplitter:
return TokenTextSplitter(
encoding_name=config.encoding_model,
tokenizer=get_tokenizer(model_config=config),
chunk_size=batch_max_tokens,
)

View File

@ -5,16 +5,16 @@
import logging
from abc import ABC, abstractmethod
from collections.abc import Callable, Collection, Iterable
from collections.abc import Callable, Iterable
from dataclasses import dataclass
from typing import Any, Literal, cast
from typing import Any, cast
import pandas as pd
import tiktoken
import graphrag.config.defaults as defs
from graphrag.index.operations.chunk_text.typing import TextChunk
from graphrag.logger.progress import ProgressTicker
from graphrag.tokenizer.get_tokenizer import get_tokenizer
from graphrag.tokenizer.tokenizer import Tokenizer
EncodedText = list[int]
DecodeFn = Callable[[EncodedText], str]
@ -25,8 +25,8 @@ logger = logging.getLogger(__name__)
@dataclass(frozen=True)
class Tokenizer:
"""Tokenizer data class."""
class TokenChunkerOptions:
"""TokenChunkerOptions data class."""
chunk_overlap: int
"""Overlap in tokens between chunks"""
@ -83,44 +83,18 @@ class NoopTextSplitter(TextSplitter):
class TokenTextSplitter(TextSplitter):
"""Token text splitter class definition."""
_allowed_special: Literal["all"] | set[str]
_disallowed_special: Literal["all"] | Collection[str]
def __init__(
self,
encoding_name: str = defs.ENCODING_MODEL,
model_name: str | None = None,
allowed_special: Literal["all"] | set[str] | None = None,
disallowed_special: Literal["all"] | Collection[str] = "all",
tokenizer: Tokenizer | None = None,
**kwargs: Any,
):
"""Init method definition."""
super().__init__(**kwargs)
if model_name is not None:
try:
enc = tiktoken.encoding_for_model(model_name)
except KeyError:
logger.exception(
"Model %s not found, using %s", model_name, encoding_name
)
enc = tiktoken.get_encoding(encoding_name)
else:
enc = tiktoken.get_encoding(encoding_name)
self._tokenizer = enc
self._allowed_special = allowed_special or set()
self._disallowed_special = disallowed_special
def encode(self, text: str) -> list[int]:
"""Encode the given text into an int-vector."""
return self._tokenizer.encode(
text,
allowed_special=self._allowed_special,
disallowed_special=self._disallowed_special,
)
self._tokenizer = tokenizer or get_tokenizer()
def num_tokens(self, text: str) -> int:
"""Return the number of tokens in a string."""
return len(self.encode(text))
return len(self._tokenizer.encode(text))
def split_text(self, text: str | list[str]) -> list[str]:
"""Split text method."""
@ -132,17 +106,17 @@ class TokenTextSplitter(TextSplitter):
msg = f"Attempting to split a non-string value, actual is {type(text)}"
raise TypeError(msg)
tokenizer = Tokenizer(
token_chunker_options = TokenChunkerOptions(
chunk_overlap=self._chunk_overlap,
tokens_per_chunk=self._chunk_size,
decode=self._tokenizer.decode,
encode=lambda text: self.encode(text),
encode=self._tokenizer.encode,
)
return split_single_text_on_tokens(text=text, tokenizer=tokenizer)
return split_single_text_on_tokens(text=text, tokenizer=token_chunker_options)
def split_single_text_on_tokens(text: str, tokenizer: Tokenizer) -> list[str]:
def split_single_text_on_tokens(text: str, tokenizer: TokenChunkerOptions) -> list[str]:
"""Split a single text and return chunks using the tokenizer."""
result = []
input_ids = tokenizer.encode(text)
@ -166,7 +140,7 @@ def split_single_text_on_tokens(text: str, tokenizer: Tokenizer) -> list[str]:
# Adapted from - https://github.com/langchain-ai/langchain/blob/77b359edf5df0d37ef0d539f678cf64f5557cb54/libs/langchain/langchain/text_splitter.py#L471
# So we could have better control over the chunking process
def split_multiple_texts_on_tokens(
texts: list[str], tokenizer: Tokenizer, tick: ProgressTicker
texts: list[str], tokenizer: TokenChunkerOptions, tick: ProgressTicker
) -> list[TextChunk]:
"""Split multiple texts and return chunks with metadata using the tokenizer."""
result = []

View File

@ -14,6 +14,10 @@ from graphrag.language_model.providers.fnllm.models import (
OpenAIChatFNLLM,
OpenAIEmbeddingFNLLM,
)
from graphrag.language_model.providers.litellm.chat_model import LitellmChatModel
from graphrag.language_model.providers.litellm.embedding_model import (
LitellmEmbeddingModel,
)
class ModelFactory:
@ -105,6 +109,7 @@ ModelFactory.register_chat(
ModelFactory.register_chat(
ModelType.OpenAIChat.value, lambda **kwargs: OpenAIChatFNLLM(**kwargs)
)
ModelFactory.register_chat(ModelType.Chat, lambda **kwargs: LitellmChatModel(**kwargs))
ModelFactory.register_embedding(
ModelType.AzureOpenAIEmbedding.value,
@ -113,3 +118,6 @@ ModelFactory.register_embedding(
ModelFactory.register_embedding(
ModelType.OpenAIEmbedding.value, lambda **kwargs: OpenAIEmbeddingFNLLM(**kwargs)
)
ModelFactory.register_embedding(
ModelType.Embedding, lambda **kwargs: LitellmEmbeddingModel(**kwargs)
)

View File

@ -0,0 +1,4 @@
# Copyright (c) 2025 Microsoft Corporation.
# Licensed under the MIT License
"""GraphRAG LiteLLM module. Provides LiteLLM-based implementations of chat and embedding models."""

View File

@ -0,0 +1,413 @@
# Copyright (c) 2025 Microsoft Corporation.
# Licensed under the MIT License
"""Chat model implementation using Litellm."""
import inspect
import json
from collections.abc import AsyncGenerator, Generator
from typing import TYPE_CHECKING, Any, cast
import litellm
from azure.identity import DefaultAzureCredential, get_bearer_token_provider
from litellm import (
CustomStreamWrapper,
ModelResponse, # type: ignore
acompletion,
completion,
)
from pydantic import BaseModel, Field
from graphrag.config.defaults import COGNITIVE_SERVICES_AUDIENCE
from graphrag.config.enums import AuthType
from graphrag.language_model.providers.litellm.request_wrappers.with_cache import (
with_cache,
)
from graphrag.language_model.providers.litellm.request_wrappers.with_logging import (
with_logging,
)
from graphrag.language_model.providers.litellm.request_wrappers.with_rate_limiter import (
with_rate_limiter,
)
from graphrag.language_model.providers.litellm.request_wrappers.with_retries import (
with_retries,
)
from graphrag.language_model.providers.litellm.types import (
AFixedModelCompletion,
FixedModelCompletion,
)
if TYPE_CHECKING:
from graphrag.cache.pipeline_cache import PipelineCache
from graphrag.config.models.language_model_config import LanguageModelConfig
from graphrag.language_model.response.base import ModelResponse as MR # noqa: N817
litellm.suppress_debug_info = True
def _create_base_completions(
model_config: "LanguageModelConfig",
) -> tuple[FixedModelCompletion, AFixedModelCompletion]:
"""Wrap the base litellm completion function with the model configuration.
Args
----
model_config: The configuration for the language model.
Returns
-------
A tuple containing the synchronous and asynchronous completion functions.
"""
model_provider = model_config.model_provider
model = model_config.deployment_name or model_config.model
base_args: dict[str, Any] = {
"drop_params": True, # LiteLLM drops unsupported params for the selected model.
"model": f"{model_provider}/{model}",
"timeout": model_config.request_timeout,
"top_p": model_config.top_p,
"n": model_config.n,
"temperature": model_config.temperature,
"frequency_penalty": model_config.frequency_penalty,
"presence_penalty": model_config.presence_penalty,
"api_base": model_config.api_base,
"api_version": model_config.api_version,
"api_key": model_config.api_key,
"organization": model_config.organization,
"proxy": model_config.proxy,
"audience": model_config.audience,
"max_tokens": model_config.max_tokens,
"max_completion_tokens": model_config.max_completion_tokens,
"reasoning_effort": model_config.reasoning_effort,
}
if model_config.auth_type == AuthType.AzureManagedIdentity:
if model_config.model_provider != "azure":
msg = "Azure Managed Identity authentication is only supported for Azure models."
raise ValueError(msg)
base_args["azure_ad_token_provider"] = get_bearer_token_provider(
DefaultAzureCredential(),
COGNITIVE_SERVICES_AUDIENCE,
)
def _base_completion(**kwargs: Any) -> ModelResponse | CustomStreamWrapper:
new_args = {**base_args, **kwargs}
if "name" in new_args:
new_args.pop("name")
return completion(**new_args)
async def _base_acompletion(**kwargs: Any) -> ModelResponse | CustomStreamWrapper:
new_args = {**base_args, **kwargs}
if "name" in new_args:
new_args.pop("name")
return await acompletion(**new_args)
return (_base_completion, _base_acompletion)
def _create_completions(
model_config: "LanguageModelConfig",
cache: "PipelineCache | None",
cache_key_prefix: str,
) -> tuple[FixedModelCompletion, AFixedModelCompletion]:
"""Wrap the base litellm completion function with the model configuration and additional features.
Wrap the base litellm completion function with instance variables based on the model configuration.
Then wrap additional features such as rate limiting, retries, and caching, if enabled.
Final function composition order:
- Logging(Cache(Retries(RateLimiter(ModelCompletion()))))
Args
----
model_config: The configuration for the language model.
cache: Optional cache for storing responses.
cache_key_prefix: Prefix for cache keys.
Returns
-------
A tuple containing the synchronous and asynchronous completion functions.
"""
completion, acompletion = _create_base_completions(model_config)
# TODO: For v2.x release, rpm/tpm can be int or str (auto) for backwards compatibility with fnllm.
# LiteLLM does not support "auto", so we have to check those values here.
# For v3 release, force rpm/tpm to be int and remove the type checks below
# and just check if rate_limit_strategy is enabled.
if model_config.rate_limit_strategy is not None:
rpm = (
model_config.requests_per_minute
if type(model_config.requests_per_minute) is int
else None
)
tpm = (
model_config.tokens_per_minute
if type(model_config.tokens_per_minute) is int
else None
)
if rpm is not None or tpm is not None:
completion, acompletion = with_rate_limiter(
sync_fn=completion,
async_fn=acompletion,
model_config=model_config,
rpm=rpm,
tpm=tpm,
)
if model_config.retry_strategy != "none":
completion, acompletion = with_retries(
sync_fn=completion,
async_fn=acompletion,
model_config=model_config,
)
if cache is not None:
completion, acompletion = with_cache(
sync_fn=completion,
async_fn=acompletion,
model_config=model_config,
cache=cache,
request_type="chat",
cache_key_prefix=cache_key_prefix,
)
completion, acompletion = with_logging(
sync_fn=completion,
async_fn=acompletion,
)
return (completion, acompletion)
class LitellmModelOutput(BaseModel):
"""A model representing the output from a language model."""
content: str = Field(description="The generated text content")
full_response: None = Field(
default=None, description="The full response from the model, if available"
)
class LitellmModelResponse(BaseModel):
"""A model representing the response from a language model."""
output: LitellmModelOutput = Field(description="The output from the model")
parsed_response: BaseModel | None = Field(
default=None, description="Parsed response from the model"
)
history: list = Field(
default_factory=list,
description="Conversation history including the prompt and response",
)
class LitellmChatModel:
"""LiteLLM-based Chat Model."""
def __init__(
self,
name: str,
config: "LanguageModelConfig",
cache: "PipelineCache | None" = None,
**kwargs: Any,
):
self.name = name
self.config = config
self.cache = cache.child(self.name) if cache else None
self.completion, self.acompletion = _create_completions(
config, self.cache, "chat"
)
def _get_kwargs(self, **kwargs: Any) -> dict[str, Any]:
"""Get model arguments supported by litellm."""
args_to_include = [
"name",
"modalities",
"prediction",
"audio",
"logit_bias",
"metadata",
"user",
"response_format",
"seed",
"tools",
"tool_choice",
"logprobs",
"top_logprobs",
"parallel_tool_calls",
"web_search_options",
"extra_headers",
"functions",
"function_call",
"thinking",
]
new_args = {k: v for k, v in kwargs.items() if k in args_to_include}
# If using JSON, check if response_format should be a Pydantic model or just a general JSON object
if kwargs.get("json"):
new_args["response_format"] = {"type": "json_object"}
if (
"json_model" in kwargs
and inspect.isclass(kwargs["json_model"])
and issubclass(kwargs["json_model"], BaseModel)
):
new_args["response_format"] = kwargs["json_model"]
return new_args
async def achat(
self, prompt: str, history: list | None = None, **kwargs: Any
) -> "MR":
"""
Generate a response for the given prompt and history.
Args
----
prompt: The prompt to generate a response for.
history: Optional chat history.
**kwargs: Additional keyword arguments (e.g., model parameters).
Returns
-------
LitellmModelResponse: The generated model response.
"""
new_kwargs = self._get_kwargs(**kwargs)
messages: list[dict[str, str]] = history or []
messages.append({"role": "user", "content": prompt})
response = await self.acompletion(messages=messages, stream=False, **new_kwargs) # type: ignore
messages.append({
"role": "assistant",
"content": response.choices[0].message.content or "", # type: ignore
})
parsed_response: BaseModel | None = None
if "response_format" in new_kwargs:
parsed_dict: dict[str, Any] = json.loads(
response.choices[0].message.content or "{}" # type: ignore
)
parsed_response = parsed_dict # type: ignore
if inspect.isclass(new_kwargs["response_format"]) and issubclass(
new_kwargs["response_format"], BaseModel
):
# If response_format is a pydantic model, instantiate it
model_initializer = cast(
"type[BaseModel]", new_kwargs["response_format"]
)
parsed_response = model_initializer(**parsed_dict)
return LitellmModelResponse(
output=LitellmModelOutput(
content=response.choices[0].message.content or "" # type: ignore
),
parsed_response=parsed_response,
history=messages,
)
async def achat_stream(
self, prompt: str, history: list | None = None, **kwargs: Any
) -> AsyncGenerator[str, None]:
"""
Generate a response for the given prompt and history.
Args
----
prompt: The prompt to generate a response for.
history: Optional chat history.
**kwargs: Additional keyword arguments (e.g., model parameters).
Returns
-------
AsyncGenerator[str, None]: The generated response as a stream of strings.
"""
new_kwargs = self._get_kwargs(**kwargs)
messages: list[dict[str, str]] = history or []
messages.append({"role": "user", "content": prompt})
response = await self.acompletion(messages=messages, stream=True, **new_kwargs) # type: ignore
async for chunk in response: # type: ignore
if chunk.choices and chunk.choices[0].delta.content:
yield chunk.choices[0].delta.content
def chat(self, prompt: str, history: list | None = None, **kwargs: Any) -> "MR":
"""
Generate a response for the given prompt and history.
Args
----
prompt: The prompt to generate a response for.
history: Optional chat history.
**kwargs: Additional keyword arguments (e.g., model parameters).
Returns
-------
LitellmModelResponse: The generated model response.
"""
new_kwargs = self._get_kwargs(**kwargs)
messages: list[dict[str, str]] = history or []
messages.append({"role": "user", "content": prompt})
response = self.completion(messages=messages, stream=False, **new_kwargs) # type: ignore
messages.append({
"role": "assistant",
"content": response.choices[0].message.content or "", # type: ignore
})
parsed_response: BaseModel | None = None
if "response_format" in new_kwargs:
parsed_dict: dict[str, Any] = json.loads(
response.choices[0].message.content or "{}" # type: ignore
)
parsed_response = parsed_dict # type: ignore
if inspect.isclass(new_kwargs["response_format"]) and issubclass(
new_kwargs["response_format"], BaseModel
):
# If response_format is a pydantic model, instantiate it
model_initializer = cast(
"type[BaseModel]", new_kwargs["response_format"]
)
parsed_response = model_initializer(**parsed_dict)
return LitellmModelResponse(
output=LitellmModelOutput(
content=response.choices[0].message.content or "" # type: ignore
),
parsed_response=parsed_response,
history=messages,
)
def chat_stream(
self, prompt: str, history: list | None = None, **kwargs: Any
) -> Generator[str, None]:
"""
Generate a response for the given prompt and history.
Args
----
prompt: The prompt to generate a response for.
history: Optional chat history.
**kwargs: Additional keyword arguments (e.g., model parameters).
Returns
-------
Generator[str, None]: The generated response as a stream of strings.
"""
new_kwargs = self._get_kwargs(**kwargs)
messages: list[dict[str, str]] = history or []
messages.append({"role": "user", "content": prompt})
response = self.completion(messages=messages, stream=True, **new_kwargs) # type: ignore
for chunk in response:
if chunk.choices and chunk.choices[0].delta.content: # type: ignore
yield chunk.choices[0].delta.content # type: ignore
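A hedged usage sketch of the structured-output path above, assuming chat_model is an already-constructed LitellmChatModel; the Summary model and prompt are illustrative:

from pydantic import BaseModel


class Summary(BaseModel):
    title: str
    bullet_points: list[str]


# json=True requests a JSON response; json_model upgrades response_format to the Pydantic class.
response = chat_model.chat(
    "Summarize the attached report as a title plus bullet points.",
    json=True,
    json_model=Summary,
)
print(response.output.content)   # raw JSON text returned by the model
print(response.parsed_response)  # Summary instance parsed from that JSON
print(response.history[-1])      # {"role": "assistant", "content": ...}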

View File

@ -0,0 +1,279 @@
# Copyright (c) 2025 Microsoft Corporation.
# Licensed under the MIT License
"""Embedding model implementation using Litellm."""
from typing import TYPE_CHECKING, Any
import litellm
from azure.identity import DefaultAzureCredential, get_bearer_token_provider
from litellm import (
EmbeddingResponse, # type: ignore
aembedding,
embedding,
)
from graphrag.config.defaults import COGNITIVE_SERVICES_AUDIENCE
from graphrag.config.enums import AuthType
from graphrag.language_model.providers.litellm.request_wrappers.with_cache import (
with_cache,
)
from graphrag.language_model.providers.litellm.request_wrappers.with_logging import (
with_logging,
)
from graphrag.language_model.providers.litellm.request_wrappers.with_rate_limiter import (
with_rate_limiter,
)
from graphrag.language_model.providers.litellm.request_wrappers.with_retries import (
with_retries,
)
from graphrag.language_model.providers.litellm.types import (
AFixedModelEmbedding,
FixedModelEmbedding,
)
if TYPE_CHECKING:
from graphrag.cache.pipeline_cache import PipelineCache
from graphrag.config.models.language_model_config import LanguageModelConfig
litellm.suppress_debug_info = True
def _create_base_embeddings(
model_config: "LanguageModelConfig",
) -> tuple[FixedModelEmbedding, AFixedModelEmbedding]:
"""Wrap the base litellm embedding function with the model configuration.
Args
----
model_config: The configuration for the language model.
Returns
-------
A tuple containing the synchronous and asynchronous embedding functions.
"""
model_provider = model_config.model_provider
model = model_config.deployment_name or model_config.model
base_args: dict[str, Any] = {
"drop_params": True, # LiteLLM drops unsupported params for the selected model.
"model": f"{model_provider}/{model}",
"timeout": model_config.request_timeout,
"api_base": model_config.api_base,
"api_version": model_config.api_version,
"api_key": model_config.api_key,
"organization": model_config.organization,
"proxy": model_config.proxy,
"audience": model_config.audience,
}
if model_config.auth_type == AuthType.AzureManagedIdentity:
if model_config.model_provider != "azure":
msg = "Azure Managed Identity authentication is only supported for Azure models."
raise ValueError(msg)
base_args["azure_ad_token_provider"] = get_bearer_token_provider(
DefaultAzureCredential(),
COGNITIVE_SERVICES_AUDIENCE,
)
def _base_embedding(**kwargs: Any) -> EmbeddingResponse:
new_args = {**base_args, **kwargs}
if "name" in new_args:
new_args.pop("name")
return embedding(**new_args)
async def _base_aembedding(**kwargs: Any) -> EmbeddingResponse:
new_args = {**base_args, **kwargs}
if "name" in new_args:
new_args.pop("name")
return await aembedding(**new_args)
return (_base_embedding, _base_aembedding)
def _create_embeddings(
model_config: "LanguageModelConfig",
cache: "PipelineCache | None",
cache_key_prefix: str,
) -> tuple[FixedModelEmbedding, AFixedModelEmbedding]:
"""Wrap the base litellm embedding function with the model configuration and additional features.
Wrap the base litellm embedding function with instance variables based on the model configuration.
Then wrap additional features such as rate limiting, retries, and caching, if enabled.
Final function composition order:
- Logging(Cache(Retries(RateLimiter(ModelEmbedding()))))
Args
----
model_config: The configuration for the language model.
cache: Optional cache for storing responses.
cache_key_prefix: Prefix for cache keys.
Returns
-------
A tuple containing the synchronous and asynchronous embedding functions.
"""
embedding, aembedding = _create_base_embeddings(model_config)
# TODO: For v2.x release, rpm/tpm can be int or str (auto) for backwards compatibility with fnllm.
# LiteLLM does not support "auto", so we have to check those values here.
# For v3 release, force rpm/tpm to be int and remove the type checks below
# and just check if rate_limit_strategy is enabled.
if model_config.rate_limit_strategy is not None:
rpm = (
model_config.requests_per_minute
if type(model_config.requests_per_minute) is int
else None
)
tpm = (
model_config.tokens_per_minute
if type(model_config.tokens_per_minute) is int
else None
)
if rpm is not None or tpm is not None:
embedding, aembedding = with_rate_limiter(
sync_fn=embedding,
async_fn=aembedding,
model_config=model_config,
rpm=rpm,
tpm=tpm,
)
if model_config.retry_strategy != "none":
embedding, aembedding = with_retries(
sync_fn=embedding,
async_fn=aembedding,
model_config=model_config,
)
if cache is not None:
embedding, aembedding = with_cache(
sync_fn=embedding,
async_fn=aembedding,
model_config=model_config,
cache=cache,
request_type="embedding",
cache_key_prefix=cache_key_prefix,
)
embedding, aembedding = with_logging(
sync_fn=embedding,
async_fn=aembedding,
)
return (embedding, aembedding)
class LitellmEmbeddingModel:
"""LiteLLM-based Embedding Model."""
def __init__(
self,
name: str,
config: "LanguageModelConfig",
cache: "PipelineCache | None" = None,
**kwargs: Any,
):
self.name = name
self.config = config
self.cache = cache.child(self.name) if cache else None
self.embedding, self.aembedding = _create_embeddings(
config, self.cache, "embeddings"
)
def _get_kwargs(self, **kwargs: Any) -> dict[str, Any]:
"""Get model arguments supported by litellm."""
args_to_include = [
"name",
"dimensions",
"encoding_format",
"timeout",
"user",
]
return {k: v for k, v in kwargs.items() if k in args_to_include}
async def aembed_batch(
self, text_list: list[str], **kwargs: Any
) -> list[list[float]]:
"""
Batch generate embeddings.
Args
----
text_list: A batch of text inputs to generate embeddings for.
**kwargs: Additional keyword arguments (e.g., model parameters).
Returns
-------
A batch of embeddings.
"""
new_kwargs = self._get_kwargs(**kwargs)
response = await self.aembedding(input=text_list, **new_kwargs)
return [emb.get("embedding", []) for emb in response.data]
async def aembed(self, text: str, **kwargs: Any) -> list[float]:
"""
Async embed.
Args:
text: The text to generate an embedding for.
**kwargs: Additional keyword arguments (e.g., model parameters).
Returns
-------
An embedding.
"""
new_kwargs = self._get_kwargs(**kwargs)
response = await self.aembedding(input=[text], **new_kwargs)
return (
response.data[0].get("embedding", [])
if response.data and response.data[0]
else []
)
def embed_batch(self, text_list: list[str], **kwargs: Any) -> list[list[float]]:
"""
Batch generate embeddings.
Args:
text_list: A batch of text inputs to generate embeddings for.
**kwargs: Additional keyword arguments (e.g., model parameters).
Returns
-------
A batch of embeddings.
"""
new_kwargs = self._get_kwargs(**kwargs)
response = self.embedding(input=text_list, **new_kwargs)
return [emb.get("embedding", []) for emb in response.data]
def embed(self, text: str, **kwargs: Any) -> list[float]:
"""
Embed a single text input.
Args:
text: The text to generate an embedding for.
**kwargs: Additional keyword arguments (e.g., model parameters).
Returns
-------
An embedding.
"""
new_kwargs = self._get_kwargs(**kwargs)
response = self.embedding(input=[text], **new_kwargs)
return (
response.data[0].get("embedding", [])
if response.data and response.data[0]
else []
)
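A usage sketch, assuming embedding_model is an already-constructed LitellmEmbeddingModel:

# Single input: returns one embedding vector (empty list if the response has no data).
vector = embedding_model.embed("graph-based retrieval augmented generation")
print(len(vector))

# Batch input: one vector per input text, in order.
vectors = embedding_model.embed_batch(
    ["community detection", "entity extraction"],
    dimensions=256,  # optional; drop_params removes it for models that do not support it
)
print(len(vectors), len(vectors[0]))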

View File

@ -0,0 +1,140 @@
# Copyright (c) 2025 Microsoft Corporation.
# Licensed under the MIT License
"""
LiteLLM cache key generation.
Modeled after the fnllm cache key generation.
https://github.com/microsoft/essex-toolkit/blob/23d3077b65c0e8f1d89c397a2968fe570a25f790/python/fnllm/fnllm/caching/base.py#L50
"""
import hashlib
import inspect
import json
from typing import TYPE_CHECKING, Any
from pydantic import BaseModel
if TYPE_CHECKING:
from graphrag.config.models.language_model_config import LanguageModelConfig
_CACHE_VERSION = 3
"""
If there's a breaking change in what we cache, we should increment this version number to invalidate existing caches.
fnllm was on cache version 2, and although we generate
similar cache keys, the objects stored in the cache by fnllm and litellm are different.
The litellm model providers cannot reuse caches generated by fnllm,
so we start litellm at cache version 3.
"""
def get_cache_key(
model_config: "LanguageModelConfig",
prefix: str,
messages: str | None = None,
input: str | None = None,
**kwargs: Any,
) -> str:
"""Generate a cache key based on the model configuration and input arguments.
Modeled after the fnllm cache key generation.
https://github.com/microsoft/essex-toolkit/blob/23d3077b65c0e8f1d89c397a2968fe570a25f790/python/fnllm/fnllm/caching/base.py#L50
Args
----
model_config: The configuration of the language model.
prefix: A prefix for the cache key.
**kwargs: Additional model input parameters.
Returns
-------
A cache key of the form `{prefix}_{data_hash}_v{version}` (with `_{name}` appended to the prefix when a `name` kwarg is provided).
"""
cache_key: dict[str, Any] = {
"parameters": _get_parameters(model_config, **kwargs),
}
if messages is not None and input is not None:
msg = "Only one of 'messages' or 'input' should be provided."
raise ValueError(msg)
if messages is not None:
cache_key["messages"] = messages
elif input is not None:
cache_key["input"] = input
else:
msg = "Either 'messages' or 'input' must be provided."
raise ValueError(msg)
data_hash = _hash(json.dumps(cache_key, sort_keys=True))
name = kwargs.get("name")
if name:
prefix += f"_{name}"
return f"{prefix}_{data_hash}_v{_CACHE_VERSION}"
def _get_parameters(
model_config: "LanguageModelConfig",
**kwargs: Any,
) -> dict[str, Any]:
"""Pluck out the parameters that define a cache key.
Use the same parameters as fnllm except request timeout.
- embeddings: https://github.com/microsoft/essex-toolkit/blob/main/python/fnllm/fnllm/openai/types/embeddings/parameters.py#L12
- chat: https://github.com/microsoft/essex-toolkit/blob/main/python/fnllm/fnllm/openai/types/chat/parameters.py#L25
Args
----
model_config: The configuration of the language model.
**kwargs: Additional model input parameters.
Returns
-------
dict[str, Any]: A dictionary of parameters that define the cache key.
"""
parameters = {
"model": model_config.deployment_name or model_config.model,
"frequency_penalty": model_config.frequency_penalty,
"max_tokens": model_config.max_tokens,
"max_completion_tokens": model_config.max_completion_tokens,
"n": model_config.n,
"presence_penalty": model_config.presence_penalty,
"temperature": model_config.temperature,
"top_p": model_config.top_p,
"reasoning_effort": model_config.reasoning_effort,
}
keys_to_cache = [
"function_call",
"functions",
"logit_bias",
"logprobs",
"parallel_tool_calls",
"seed",
"service_tier",
"stop",
"tool_choice",
"tools",
"top_logprobs",
"user",
"dimensions",
"encoding_format",
]
parameters.update({key: kwargs.get(key) for key in keys_to_cache if key in kwargs})
response_format = kwargs.get("response_format")
if inspect.isclass(response_format) and issubclass(response_format, BaseModel):
parameters["response_format"] = str(response_format)
elif response_format is not None:
parameters["response_format"] = response_format
return parameters
def _hash(input: str) -> str:
"""Generate a hash for the input string."""
return hashlib.sha256(input.encode()).hexdigest()
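To make the key shape concrete, a small standalone sketch mirroring the hashing above (illustrative parameter values; the real key is built from the full parameter set returned by _get_parameters):

import hashlib
import json

# Mirror of get_cache_key: sha256 over the sorted-JSON payload, then "{prefix}_{hash}_v{version}".
payload = {
    "parameters": {"model": "gpt-4o", "temperature": 0, "top_p": 1, "n": 1},
    "messages": '[{"role": "user", "content": "hello"}]',
}
data_hash = hashlib.sha256(json.dumps(payload, sort_keys=True).encode()).hexdigest()
cache_key = f"chat_{data_hash}_v3"
print(cache_key)  # e.g. chat_0f9c..._v3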

View File

@ -0,0 +1,4 @@
# Copyright (c) 2025 Microsoft Corporation.
# Licensed under the MIT License
"""LiteLLM completion/embedding function wrappers."""

View File

@ -0,0 +1,107 @@
# Copyright (c) 2025 Microsoft Corporation.
# Licensed under the MIT License
"""LiteLLM completion/embedding cache wrapper."""
import asyncio
from typing import TYPE_CHECKING, Any, Literal
from litellm import EmbeddingResponse, ModelResponse # type: ignore
from graphrag.language_model.providers.litellm.get_cache_key import get_cache_key
from graphrag.language_model.providers.litellm.types import (
AsyncLitellmRequestFunc,
LitellmRequestFunc,
)
if TYPE_CHECKING:
from graphrag.cache.pipeline_cache import PipelineCache
from graphrag.config.models.language_model_config import LanguageModelConfig
def with_cache(
*,
sync_fn: LitellmRequestFunc,
async_fn: AsyncLitellmRequestFunc,
model_config: "LanguageModelConfig",
cache: "PipelineCache",
request_type: Literal["chat", "embedding"],
cache_key_prefix: str,
) -> tuple[LitellmRequestFunc, AsyncLitellmRequestFunc]:
"""
Wrap the synchronous and asynchronous request functions with caching.
Args
----
sync_fn: The synchronous chat/embedding request function to wrap.
async_fn: The asynchronous chat/embedding request function to wrap.
model_config: The configuration for the language model.
cache: The cache to use for storing responses.
request_type: The type of request being made, either "chat" or "embedding".
cache_key_prefix: The prefix to use for cache keys.
Returns
-------
A tuple containing the wrapped synchronous and asynchronous chat/embedding request functions.
"""
def _wrapped_with_cache(**kwargs: Any) -> Any:
is_streaming = kwargs.get("stream", False)
if is_streaming:
return sync_fn(**kwargs)
cache_key = get_cache_key(
model_config=model_config, prefix=cache_key_prefix, **kwargs
)
event_loop = asyncio.get_event_loop()
cached_response = event_loop.run_until_complete(cache.get(cache_key))
if (
cached_response is not None
and isinstance(cached_response, dict)
and "response" in cached_response
and cached_response["response"] is not None
and isinstance(cached_response["response"], dict)
):
try:
if request_type == "chat":
return ModelResponse(**cached_response["response"])
return EmbeddingResponse(**cached_response["response"])
except Exception: # noqa: BLE001
# Try to retrieve value from cache but if it fails, continue
# to make the request.
...
response = sync_fn(**kwargs)
event_loop.run_until_complete(
cache.set(cache_key, {"response": response.model_dump()})
)
return response
async def _wrapped_with_cache_async(
**kwargs: Any,
) -> Any:
is_streaming = kwargs.get("stream", False)
if is_streaming:
return await async_fn(**kwargs)
cache_key = get_cache_key(
model_config=model_config, prefix=cache_key_prefix, **kwargs
)
cached_response = await cache.get(cache_key)
if (
cached_response is not None
and isinstance(cached_response, dict)
and "response" in cached_response
and cached_response["response"] is not None
and isinstance(cached_response["response"], dict)
):
try:
if request_type == "chat":
return ModelResponse(**cached_response["response"])
return EmbeddingResponse(**cached_response["response"])
except Exception: # noqa: BLE001
# Try to retrieve value from cache but if it fails, continue
# to make the request.
...
response = await async_fn(**kwargs)
await cache.set(cache_key, {"response": response.model_dump()})
return response
return (_wrapped_with_cache, _wrapped_with_cache_async)

View File

@ -0,0 +1,56 @@
# Copyright (c) 2025 Microsoft Corporation.
# Licensed under the MIT License
"""LiteLLM completion/embedding logging wrapper."""
import logging
from typing import Any
from graphrag.language_model.providers.litellm.types import (
AsyncLitellmRequestFunc,
LitellmRequestFunc,
)
logger = logging.getLogger(__name__)
def with_logging(
*,
sync_fn: LitellmRequestFunc,
async_fn: AsyncLitellmRequestFunc,
) -> tuple[LitellmRequestFunc, AsyncLitellmRequestFunc]:
"""
Wrap the synchronous and asynchronous request functions with logging.
Args
----
sync_fn: The synchronous chat/embedding request function to wrap.
async_fn: The asynchronous chat/embedding request function to wrap.
Returns
-------
A tuple containing the wrapped synchronous and asynchronous chat/embedding request functions.
"""
def _wrapped_with_logging(**kwargs: Any) -> Any:
try:
return sync_fn(**kwargs)
except Exception as e:
logger.exception(
f"with_logging: Request failed with exception={e}", # noqa: G004, TRY401
)
raise
async def _wrapped_with_logging_async(
**kwargs: Any,
) -> Any:
try:
return await async_fn(**kwargs)
except Exception as e:
logger.exception(
f"with_logging: Async request failed with exception={e}", # noqa: G004, TRY401
)
raise
return (_wrapped_with_logging, _wrapped_with_logging_async)

View File

@ -0,0 +1,97 @@
# Copyright (c) 2025 Microsoft Corporation.
# Licensed under the MIT License
"""LiteLLM completion/embedding rate limiter wrapper."""
from typing import TYPE_CHECKING, Any
from litellm import token_counter # type: ignore
from graphrag.language_model.providers.litellm.services.rate_limiter.rate_limiter_factory import (
RateLimiterFactory,
)
from graphrag.language_model.providers.litellm.types import (
AsyncLitellmRequestFunc,
LitellmRequestFunc,
)
if TYPE_CHECKING:
from graphrag.config.models.language_model_config import LanguageModelConfig
def with_rate_limiter(
*,
sync_fn: LitellmRequestFunc,
async_fn: AsyncLitellmRequestFunc,
model_config: "LanguageModelConfig",
rpm: int | None = None,
tpm: int | None = None,
) -> tuple[LitellmRequestFunc, AsyncLitellmRequestFunc]:
"""
Wrap the synchronous and asynchronous request functions with rate limiting.
Args
----
sync_fn: The synchronous chat/embedding request function to wrap.
async_fn: The asynchronous chat/embedding request function to wrap.
model_config: The configuration for the language model.
rpm: An optional requests per minute limit.
tpm: An optional tokens per minute limit.
If both `rpm` and `tpm` are None, rate limiting is disabled.
Returns
-------
A tuple containing the wrapped synchronous and asynchronous chat/embedding request functions.
"""
rate_limiter_factory = RateLimiterFactory()
if (
model_config.rate_limit_strategy is None
or model_config.rate_limit_strategy not in rate_limiter_factory
):
msg = f"Rate Limiter strategy '{model_config.rate_limit_strategy}' is none or not registered. Available strategies: {', '.join(rate_limiter_factory.keys())}"
raise ValueError(msg)
rate_limiter_service = rate_limiter_factory.create(
strategy=model_config.rate_limit_strategy, rpm=rpm, tpm=tpm
)
max_tokens = model_config.max_completion_tokens or model_config.max_tokens or 0
def _wrapped_with_rate_limiter(**kwargs: Any) -> Any:
token_count = max_tokens
if "messages" in kwargs:
token_count += token_counter(
model=model_config.model,
messages=kwargs["messages"],
)
elif "input" in kwargs:
token_count += token_counter(
model=model_config.model,
text=kwargs["input"],
)
with rate_limiter_service.acquire(token_count=token_count):
return sync_fn(**kwargs)
async def _wrapped_with_rate_limiter_async(
**kwargs: Any,
) -> Any:
token_count = max_tokens
if "messages" in kwargs:
token_count += token_counter(
model=model_config.model,
messages=kwargs["messages"],
)
elif "input" in kwargs:
token_count += token_counter(
model=model_config.model,
text=kwargs["input"],
)
with rate_limiter_service.acquire(token_count=token_count):
return await async_fn(**kwargs)
return (_wrapped_with_rate_limiter, _wrapped_with_rate_limiter_async)

View File

@ -0,0 +1,54 @@
# Copyright (c) 2025 Microsoft Corporation.
# Licensed under the MIT License
"""LiteLLM completion/embedding retries wrapper."""
from typing import TYPE_CHECKING, Any
from graphrag.language_model.providers.litellm.services.retry.retry_factory import (
RetryFactory,
)
from graphrag.language_model.providers.litellm.types import (
AsyncLitellmRequestFunc,
LitellmRequestFunc,
)
if TYPE_CHECKING:
from graphrag.config.models.language_model_config import LanguageModelConfig
def with_retries(
*,
sync_fn: LitellmRequestFunc,
async_fn: AsyncLitellmRequestFunc,
model_config: "LanguageModelConfig",
) -> tuple[LitellmRequestFunc, AsyncLitellmRequestFunc]:
"""
Wrap the synchronous and asynchronous request functions with retries.
Args
----
sync_fn: The synchronous chat/embedding request function to wrap.
async_fn: The asynchronous chat/embedding request function to wrap.
model_config: The configuration for the language model.
Returns
-------
A tuple containing the wrapped synchronous and asynchronous chat/embedding request functions.
"""
retry_factory = RetryFactory()
retry_service = retry_factory.create(
strategy=model_config.retry_strategy,
max_attempts=model_config.max_retries,
max_retry_wait=model_config.max_retry_wait,
)
def _wrapped_with_retries(**kwargs: Any) -> Any:
return retry_service.retry(func=sync_fn, **kwargs)
async def _wrapped_with_retries_async(
**kwargs: Any,
) -> Any:
return await retry_service.aretry(func=async_fn, **kwargs)
return (_wrapped_with_retries, _wrapped_with_retries_async)

View File

@ -0,0 +1,4 @@
# Copyright (c) 2025 Microsoft Corporation.
# Licensed under the MIT License
"""LiteLLM Services."""

View File

@ -0,0 +1,4 @@
# Copyright (c) 2025 Microsoft Corporation.
# Licensed under the MIT License
"""LiteLLM Rate Limiter."""

View File

@ -0,0 +1,37 @@
# Copyright (c) 2025 Microsoft Corporation.
# Licensed under the MIT License
"""LiteLLM Rate Limiter."""
from abc import ABC, abstractmethod
from collections.abc import Iterator
from contextlib import contextmanager
from typing import Any
class RateLimiter(ABC):
"""Abstract base class for rate limiters."""
@abstractmethod
def __init__(
self,
/,
**kwargs: Any,
) -> None: ...
@abstractmethod
@contextmanager
def acquire(self, *, token_count: int) -> Iterator[None]:
"""
Acquire Rate Limiter.
Args
----
token_count: The estimated number of tokens for the current request.
Yields
------
None: This context manager does not return any value.
"""
msg = "RateLimiter subclasses must implement the acquire method."
raise NotImplementedError(msg)

View File

@ -0,0 +1,22 @@
# Copyright (c) 2025 Microsoft Corporation.
# Licensed under the MIT License
"""LiteLLM Rate Limiter Factory."""
from graphrag.config.defaults import DEFAULT_RATE_LIMITER_SERVICES
from graphrag.factory.factory import Factory
from graphrag.language_model.providers.litellm.services.rate_limiter.rate_limiter import (
RateLimiter,
)
class RateLimiterFactory(Factory[RateLimiter]):
"""Singleton factory for creating rate limiter services."""
rate_limiter_factory = RateLimiterFactory()
for service_name, service_cls in DEFAULT_RATE_LIMITER_SERVICES.items():
rate_limiter_factory.register(
strategy=service_name, service_initializer=service_cls
)

View File

@ -0,0 +1,133 @@
# Copyright (c) 2025 Microsoft Corporation.
# Licensed under the MIT License
"""LiteLLM Static Rate Limiter."""
import threading
import time
from collections import deque
from collections.abc import Iterator
from contextlib import contextmanager
from typing import Any
from graphrag.language_model.providers.litellm.services.rate_limiter.rate_limiter import (
RateLimiter,
)
class StaticRateLimiter(RateLimiter):
"""Static Rate Limiter implementation."""
def __init__(
self,
*,
rpm: int | None = None,
tpm: int | None = None,
default_stagger: float = 0.0,
period_in_seconds: int = 60,
**kwargs: Any,
):
if rpm is None and tpm is None:
msg = "TPM and RPM cannot both be None (disabled); at least one must be set to a positive integer."
raise ValueError(msg)
if (rpm is not None and rpm <= 0) or (tpm is not None and tpm <= 0):
msg = "RPM and TPM must be either None (disabled) or positive integers."
raise ValueError(msg)
if default_stagger < 0:
msg = "Default stagger must be >= 0."
raise ValueError(msg)
if period_in_seconds <= 0:
msg = "Period in seconds must be a positive integer."
raise ValueError(msg)
self.rpm = rpm
self.tpm = tpm
self._lock = threading.Lock()
self.rate_queue: deque[float] = deque()
self.token_queue: deque[int] = deque()
self.period_in_seconds = period_in_seconds
self._last_time: float | None = None
self.stagger = default_stagger
if self.rpm is not None and self.rpm > 0:
self.stagger = self.period_in_seconds / self.rpm
@contextmanager
def acquire(self, *, token_count: int) -> Iterator[None]:
"""
Acquire Rate Limiter.
Args
----
token_count: The estimated number of tokens for the current request.
Yields
------
None: This context manager does not return any value.
"""
while True:
with self._lock:
current_time = time.time()
# Use two sliding windows to keep track of #requests and tokens per period
# Drop old requests and tokens out of the sliding windows
while (
len(self.rate_queue) > 0
and self.rate_queue[0] < current_time - self.period_in_seconds
):
self.rate_queue.popleft()
self.token_queue.popleft()
# If the sliding window still exceeds the request limit, wait again.
# Waiting requires reacquiring the lock, allowing other threads
# to check whether their requests fit within the rate limiting windows,
# which matters more for the token limit than for the request limit.
if (
self.rpm is not None
and self.rpm > 0
and len(self.rate_queue) >= self.rpm
):
continue
# Check if the current token window exceeds the token limit; if it does, wait again.
# This does not account for the tokens of the current request.
# That is intentional: we want to allow the current request to be processed
# if it is larger than the TPM limit but smaller than the context window.
# TPM is a soft rate limit, not the hard limit imposed by the context window.
if (
self.tpm is not None
and self.tpm > 0
and sum(self.token_queue) >= self.tpm
):
continue
# This check applies when the current request itself fits within the TPM limit:
# if adding its tokens would push the window over the limit, wait.
# If the request alone exceeds the TPM limit, skip this check and let it be processed.
if (
self.tpm is not None
and self.tpm > 0
and token_count <= self.tpm
and sum(self.token_queue) + token_count > self.tpm
):
continue
# If there was a previous call, check if we need to stagger
if (
self.stagger > 0
and (
self._last_time # is None if this is the first hit to the rate limiter
and current_time - self._last_time
< self.stagger # If more time has passed than the stagger time, we can proceed
)
):
time.sleep(self.stagger - (current_time - self._last_time))
current_time = time.time()
# Add the current request to the sliding window
self.rate_queue.append(current_time)
self.token_queue.append(token_count)
self._last_time = current_time
break
yield
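
A usage sketch for the limiter above; the rpm/tpm values are arbitrary and the wrapped call is a placeholder, not part of this change.

# Sketch: throttle outgoing requests to 60 requests and 10,000 tokens per minute.
limiter = StaticRateLimiter(rpm=60, tpm=10_000)


def send_request(prompt: str, estimated_tokens: int) -> None:
    # Blocks until the request fits inside the rpm/tpm sliding windows.
    with limiter.acquire(token_count=estimated_tokens):
        ...  # placeholder for the actual model call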

View File

@ -0,0 +1,4 @@
# Copyright (c) 2025 Microsoft Corporation.
# Licensed under the MIT License
"""LiteLLM Retry Services."""

View File

@ -0,0 +1,83 @@
# Copyright (c) 2025 Microsoft Corporation.
# Licensed under the MIT License
"""LiteLLM Exponential Retry Service."""
import asyncio
import logging
import random
import time
from collections.abc import Awaitable, Callable
from typing import Any
from graphrag.language_model.providers.litellm.services.retry.retry import Retry
logger = logging.getLogger(__name__)
class ExponentialRetry(Retry):
"""LiteLLM Exponential Retry Service."""
def __init__(
self,
*,
max_attempts: int = 5,
base_delay: float = 2.0,
jitter: bool = True,
**kwargs: Any,
):
if max_attempts <= 0:
msg = "max_attempts must be greater than 0."
raise ValueError(msg)
if base_delay <= 1.0:
msg = "base_delay must be greater than 1.0."
raise ValueError(msg)
self._max_attempts = max_attempts
self._base_delay = base_delay
self._jitter = jitter
def retry(self, func: Callable[..., Any], **kwargs: Any) -> Any:
"""Retry a synchronous function."""
retries = 0
delay = 1.0 # Initial delay in seconds
while True:
try:
return func(**kwargs)
except Exception as e:
if retries >= self._max_attempts:
logger.exception(
f"ExponentialRetry: Max retries exceeded, retries={retries}, max_retries={self._max_attempts}, exception={e}", # noqa: G004, TRY401
)
raise
retries += 1
delay *= self._base_delay
logger.exception(
f"ExponentialRetry: Request failed, retrying, retries={retries}, delay={delay}, max_retries={self._max_attempts}, exception={e}", # noqa: G004, TRY401
)
time.sleep(delay + (self._jitter * random.uniform(0, 1))) # noqa: S311
async def aretry(
self,
func: Callable[..., Awaitable[Any]],
**kwargs: Any,
) -> Any:
"""Retry an asynchronous function."""
retries = 0
delay = 1.0 # Initial delay in seconds
while True:
try:
return await func(**kwargs)
except Exception as e:
if retries >= self._max_attempts:
logger.exception(
f"ExponentialRetry: Max retries exceeded, retries={retries}, max_retries={self._max_attempts}, exception={e}", # noqa: G004, TRY401
)
raise
retries += 1
delay *= self._base_delay
logger.exception(
f"ExponentialRetry: Request failed, retrying, retries={retries}, delay={delay}, max_retries={self._max_attempts}, exception={e}", # noqa: G004, TRY401
)
await asyncio.sleep(delay + (self._jitter * random.uniform(0, 1))) # noqa: S311
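
A usage sketch for the retry service above; the wrapped callable is a placeholder, not part of this change.

# Sketch: wrap a flaky call with exponential backoff plus jitter.
retry = ExponentialRetry(max_attempts=3, base_delay=2.0, jitter=True)


def flaky_call(*, payload: str) -> str:
    # Placeholder for a request that may raise a transient error.
    return payload


result = retry.retry(flaky_call, payload="hello")
# Asynchronous variant: result = await retry.aretry(some_async_callable, payload="hello")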

View File

@ -0,0 +1,77 @@
# Copyright (c) 2025 Microsoft Corporation.
# Licensed under the MIT License
"""LiteLLM Incremental Wait Retry Service."""
import asyncio
import logging
import time
from collections.abc import Awaitable, Callable
from typing import Any
from graphrag.language_model.providers.litellm.services.retry.retry import Retry
logger = logging.getLogger(__name__)
class IncrementalWaitRetry(Retry):
"""LiteLLM Incremental Wait Retry Service."""
def __init__(
self,
*,
max_retry_wait: float,
max_attempts: int = 5,
**kwargs: Any,
):
if max_attempts <= 0 or max_retry_wait <= 0:
msg = "max_attempts and max_retry_wait must be greater than 0."
raise ValueError(msg)
self._max_attempts = max_attempts
self._max_retry_wait = max_retry_wait
self._increment = max_retry_wait / max_attempts
def retry(self, func: Callable[..., Any], **kwargs: Any) -> Any:
"""Retry a synchronous function."""
retries = 0
delay = 0.0
while True:
try:
return func(**kwargs)
except Exception as e:
if retries >= self._max_attempts:
logger.exception(
f"IncrementalWaitRetry: Max retries exceeded, retries={retries}, max_retries={self._max_attempts}, exception={e}", # noqa: G004, TRY401
)
raise
retries += 1
delay += self._increment
logger.exception(
f"IncrementalWaitRetry: Request failed, retrying after incremental delay, retries={retries}, delay={delay}, max_retries={self._max_attempts}, exception={e}", # noqa: G004, TRY401
)
time.sleep(delay)
async def aretry(
self,
func: Callable[..., Awaitable[Any]],
**kwargs: Any,
) -> Any:
"""Retry an asynchronous function."""
retries = 0
delay = 0.0
while True:
try:
return await func(**kwargs)
except Exception as e:
if retries >= self._max_attempts:
logger.exception(
f"IncrementalWaitRetry: Max retries exceeded, retries={retries}, max_retries={self._max_attempts}, exception={e}", # noqa: G004, TRY401
)
raise
retries += 1
delay += self._increment
logger.exception(
f"IncrementalWaitRetry: Request failed, retrying after incremental delay, retries={retries}, delay={delay}, max_retries={self._max_attempts}, exception={e}", # noqa: G004, TRY401
)
await asyncio.sleep(delay)

View File

@ -0,0 +1,66 @@
# Copyright (c) 2025 Microsoft Corporation.
# Licensed under the MIT License
"""LiteLLM Native Retry Service."""
import logging
from collections.abc import Awaitable, Callable
from typing import Any
from graphrag.language_model.providers.litellm.services.retry.retry import Retry
logger = logging.getLogger(__name__)
class NativeRetry(Retry):
"""LiteLLM Native Retry Service."""
def __init__(
self,
*,
max_attempts: int = 5,
**kwargs: Any,
):
if max_attempts <= 0:
msg = "max_attempts must be greater than 0."
raise ValueError(msg)
self._max_attempts = max_attempts
def retry(self, func: Callable[..., Any], **kwargs: Any) -> Any:
"""Retry a synchronous function."""
retries = 0
while True:
try:
return func(**kwargs)
except Exception as e:
if retries >= self._max_attempts:
logger.exception(
f"NativeRetry: Max retries exceeded, retries={retries}, max_retries={self._max_attempts}, exception={e}", # noqa: G004, TRY401
)
raise
retries += 1
logger.exception(
f"NativeRetry: Request failed, immediately retrying, retries={retries}, max_retries={self._max_attempts}, exception={e}", # noqa: G004, TRY401
)
async def aretry(
self,
func: Callable[..., Awaitable[Any]],
**kwargs: Any,
) -> Any:
"""Retry an asynchronous function."""
retries = 0
while True:
try:
return await func(**kwargs)
except Exception as e:
if retries >= self._max_attempts:
logger.exception(
f"NativeRetry: Max retries exceeded, retries={retries}, max_retries={self._max_attempts}, exception={e}", # noqa: G004, TRY401
)
raise
retries += 1
logger.exception(
f"NativeRetry: Request failed, immediately retrying, retries={retries}, max_retries={self._max_attempts}, exception={e}", # noqa: G004, TRY401
)

View File

@ -0,0 +1,75 @@
# Copyright (c) 2025 Microsoft Corporation.
# Licensed under the MIT License
"""LiteLLM Random Wait Retry Service."""
import asyncio
import logging
import random
import time
from collections.abc import Awaitable, Callable
from typing import Any
from graphrag.language_model.providers.litellm.services.retry.retry import Retry
logger = logging.getLogger(__name__)
class RandomWaitRetry(Retry):
"""LiteLLM Random Wait Retry Service."""
def __init__(
self,
*,
max_retry_wait: float,
max_attempts: int = 5,
**kwargs: Any,
):
if max_attempts <= 0 or max_retry_wait <= 0:
msg = "max_attempts and max_retry_wait must be greater than 0."
raise ValueError(msg)
self._max_attempts = max_attempts
self._max_retry_wait = max_retry_wait
def retry(self, func: Callable[..., Any], **kwargs: Any) -> Any:
"""Retry a synchronous function."""
retries = 0
while True:
try:
return func(**kwargs)
except Exception as e:
if retries >= self._max_attempts:
logger.exception(
f"RandomWaitRetry: Max retries exceeded, retries={retries}, max_retries={self._max_attempts}, exception={e}", # noqa: G004, TRY401
)
raise
retries += 1
delay = random.uniform(0, self._max_retry_wait) # noqa: S311
logger.exception(
f"RandomWaitRetry: Request failed, retrying after random delay, retries={retries}, delay={delay}, max_retries={self._max_attempts}, exception={e}", # noqa: G004, TRY401
)
time.sleep(delay)
async def aretry(
self,
func: Callable[..., Awaitable[Any]],
**kwargs: Any,
) -> Any:
"""Retry an asynchronous function."""
retries = 0
while True:
try:
return await func(**kwargs)
except Exception as e:
if retries >= self._max_attempts:
logger.exception(
f"RandomWaitRetry: Max retries exceeded, retries={retries}, max_retries={self._max_attempts}, exception={e}", # noqa: G004, TRY401
)
raise
retries += 1
delay = random.uniform(0, self._max_retry_wait) # noqa: S311
logger.exception(
f"RandomWaitRetry: Request failed, retrying after random delay, retries={retries}, delay={delay}, max_retries={self._max_attempts}, exception={e}", # noqa: G004, TRY401
)
await asyncio.sleep(delay)

View File

@ -0,0 +1,33 @@
# Copyright (c) 2025 Microsoft Corporation.
# Licensed under the MIT License
"""LiteLLM Retry Abstract Base Class."""
from abc import ABC, abstractmethod
from collections.abc import Awaitable, Callable
from typing import Any
class Retry(ABC):
"""LiteLLM Retry Abstract Base Class."""
@abstractmethod
def __init__(self, /, **kwargs: Any):
msg = "Retry subclasses must implement the __init__ method."
raise NotImplementedError(msg)
@abstractmethod
def retry(self, func: Callable[..., Any], **kwargs: Any) -> Any:
"""Retry a synchronous function."""
msg = "Subclasses must implement this method"
raise NotImplementedError(msg)
@abstractmethod
async def aretry(
self,
func: Callable[..., Awaitable[Any]],
**kwargs: Any,
) -> Any:
"""Retry an asynchronous function."""
msg = "Subclasses must implement this method"
raise NotImplementedError(msg)

View File

@ -0,0 +1,18 @@
# Copyright (c) 2025 Microsoft Corporation.
# Licensed under the MIT License
"""LiteLLM Retry Factory."""
from graphrag.config.defaults import DEFAULT_RETRY_SERVICES
from graphrag.factory.factory import Factory
from graphrag.language_model.providers.litellm.services.retry.retry import Retry
class RetryFactory(Factory[Retry]):
"""Singleton factory for creating retry services."""
retry_factory = RetryFactory()
for service_name, service_cls in DEFAULT_RETRY_SERVICES.items():
retry_factory.register(strategy=service_name, service_initializer=service_cls)

View File

@ -0,0 +1,235 @@
# Copyright (c) 2025 Microsoft Corporation.
# Licensed under the MIT License
"""LiteLLM types."""
from typing import (
Any,
Protocol,
runtime_checkable,
)
from litellm import (
AnthropicThinkingParam,
BaseModel,
ChatCompletionAudioParam,
ChatCompletionModality,
ChatCompletionPredictionContentParam,
CustomStreamWrapper,
EmbeddingResponse, # type: ignore
ModelResponse, # type: ignore
OpenAIWebSearchOptions,
)
from openai.types.chat.chat_completion import (
ChatCompletion,
Choice,
)
from openai.types.chat.chat_completion_chunk import ChatCompletionChunk, ChoiceDelta
from openai.types.chat.chat_completion_chunk import Choice as ChunkChoice
from openai.types.chat.chat_completion_message import ChatCompletionMessage
from openai.types.chat.chat_completion_message_param import ChatCompletionMessageParam
from openai.types.completion_usage import (
CompletionTokensDetails,
CompletionUsage,
PromptTokensDetails,
)
from openai.types.create_embedding_response import CreateEmbeddingResponse, Usage
from openai.types.embedding import Embedding
LMChatCompletionMessageParam = ChatCompletionMessageParam | dict[str, str]
LMChatCompletion = ChatCompletion
LMChoice = Choice
LMChatCompletionMessage = ChatCompletionMessage
LMChatCompletionChunk = ChatCompletionChunk
LMChoiceChunk = ChunkChoice
LMChoiceDelta = ChoiceDelta
LMCompletionUsage = CompletionUsage
LMPromptTokensDetails = PromptTokensDetails
LMCompletionTokensDetails = CompletionTokensDetails
LMEmbeddingResponse = CreateEmbeddingResponse
LMEmbedding = Embedding
LMEmbeddingUsage = Usage
@runtime_checkable
class FixedModelCompletion(Protocol):
"""
Synchronous chat completion function.
Same signature as litellm.completion but without the `model` parameter
as this is already set in the model configuration.
"""
def __call__(
self,
*,
messages: list = [], # type: ignore # noqa: B006
stream: bool | None = None,
stream_options: dict | None = None, # type: ignore
stop=None, # type: ignore
max_completion_tokens: int | None = None,
max_tokens: int | None = None,
modalities: list[ChatCompletionModality] | None = None,
prediction: ChatCompletionPredictionContentParam | None = None,
audio: ChatCompletionAudioParam | None = None,
logit_bias: dict | None = None, # type: ignore
user: str | None = None,
# openai v1.0+ new params
response_format: dict | type[BaseModel] | None = None, # type: ignore
seed: int | None = None,
tools: list | None = None, # type: ignore
tool_choice: str | dict | None = None, # type: ignore
logprobs: bool | None = None,
top_logprobs: int | None = None,
parallel_tool_calls: bool | None = None,
web_search_options: OpenAIWebSearchOptions | None = None,
deployment_id=None, # type: ignore
extra_headers: dict | None = None, # type: ignore
# soon to be deprecated params by OpenAI
functions: list | None = None, # type: ignore
function_call: str | None = None,
# Optional liteLLM function params
thinking: AnthropicThinkingParam | None = None,
**kwargs: Any,
) -> ModelResponse | CustomStreamWrapper:
"""Chat completion function."""
...
@runtime_checkable
class AFixedModelCompletion(Protocol):
"""
Asynchronous chat completion function.
Same signature as litellm.acompletion but without the `model` parameter
as this is already set in the model configuration.
"""
async def __call__(
self,
*,
# Optional OpenAI params: see https://platform.openai.com/docs/api-reference/chat/create
messages: list = [], # type: ignore # noqa: B006
stream: bool | None = None,
stream_options: dict | None = None, # type: ignore
stop=None, # type: ignore
max_completion_tokens: int | None = None,
max_tokens: int | None = None,
modalities: list[ChatCompletionModality] | None = None,
prediction: ChatCompletionPredictionContentParam | None = None,
audio: ChatCompletionAudioParam | None = None,
logit_bias: dict | None = None, # type: ignore
user: str | None = None,
# openai v1.0+ new params
response_format: dict | type[BaseModel] | None = None, # type: ignore
seed: int | None = None,
tools: list | None = None, # type: ignore
tool_choice: str | dict | None = None, # type: ignore
logprobs: bool | None = None,
top_logprobs: int | None = None,
parallel_tool_calls: bool | None = None,
web_search_options: OpenAIWebSearchOptions | None = None,
deployment_id=None, # type: ignore
extra_headers: dict | None = None, # type: ignore
# soon to be deprecated params by OpenAI
functions: list | None = None, # type: ignore
function_call: str | None = None,
# Optional liteLLM function params
thinking: AnthropicThinkingParam | None = None,
**kwargs: Any,
) -> ModelResponse | CustomStreamWrapper:
"""Chat completion function."""
...
@runtime_checkable
class FixedModelEmbedding(Protocol):
"""
Synchronous embedding function.
Same signature as litellm.embedding but without the `model` parameter
as this is already set in the model configuration.
"""
def __call__(
self,
*,
request_id: str | None = None,
input: list = [], # type: ignore # noqa: B006
# Optional params
dimensions: int | None = None,
encoding_format: str | None = None,
timeout: int = 600, # default to 10 minutes
# set api_base, api_version, api_key
api_base: str | None = None,
api_version: str | None = None,
api_key: str | None = None,
api_type: str | None = None,
caching: bool = False,
user: str | None = None,
**kwargs: Any,
) -> EmbeddingResponse:
"""Embedding function."""
...
@runtime_checkable
class AFixedModelEmbedding(Protocol):
"""
Asynchronous embedding function.
Same signature as litellm.aembedding but without the `model` parameter
as this is already set in the model configuration.
"""
async def __call__(
self,
*,
request_id: str | None = None,
input: list = [], # type: ignore # noqa: B006
# Optional params
dimensions: int | None = None,
encoding_format: str | None = None,
timeout: int = 600, # default to 10 minutes
# set api_base, api_version, api_key
api_base: str | None = None,
api_version: str | None = None,
api_key: str | None = None,
api_type: str | None = None,
caching: bool = False,
user: str | None = None,
**kwargs: Any,
) -> EmbeddingResponse:
"""Embedding function."""
...
@runtime_checkable
class LitellmRequestFunc(Protocol):
"""
Synchronous request function.
Represents either a chat completion or embedding function.
"""
def __call__(self, /, **kwargs: Any) -> Any:
"""Request function."""
...
@runtime_checkable
class AsyncLitellmRequestFunc(Protocol):
"""
Asynchronous request function.
Represents either a chat completion or embedding function.
"""
async def __call__(self, /, **kwargs: Any) -> Any:
"""Request function."""
...
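
As a rough sketch of how these protocols can be satisfied (an assumption about the wiring, not the provider implementation in this change): pre-binding litellm's `model` argument yields callables with the fixed-model signatures above.

# Sketch only: bind the model name so the resulting callables no longer take `model`.
# The model names are examples; real provider credentials are required at runtime.
import functools

import litellm

fixed_completion = functools.partial(litellm.completion, model="gpt-4o-mini")
afixed_embedding = functools.partial(litellm.aembedding, model="text-embedding-3-small")

response = fixed_completion(messages=[{"role": "user", "content": "hello"}])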

View File

@ -5,8 +5,6 @@
from pathlib import Path
import graphrag.config.defaults as defs
from graphrag.index.utils.tokens import num_tokens_from_string
from graphrag.prompt_tune.template.extract_graph import (
EXAMPLE_EXTRACTION_TEMPLATE,
GRAPH_EXTRACTION_JSON_PROMPT,
@ -14,6 +12,8 @@ from graphrag.prompt_tune.template.extract_graph import (
UNTYPED_EXAMPLE_EXTRACTION_TEMPLATE,
UNTYPED_GRAPH_EXTRACTION_PROMPT,
)
from graphrag.tokenizer.get_tokenizer import get_tokenizer
from graphrag.tokenizer.tokenizer import Tokenizer
EXTRACT_GRAPH_FILENAME = "extract_graph.txt"
@ -24,7 +24,7 @@ def create_extract_graph_prompt(
examples: list[str],
language: str,
max_token_count: int,
encoding_model: str = defs.ENCODING_MODEL,
tokenizer: Tokenizer | None = None,
json_mode: bool = False,
output_path: Path | None = None,
min_examples_required: int = 2,
@ -38,7 +38,7 @@ def create_extract_graph_prompt(
- docs (list[str]): The list of documents to extract entities from
- examples (list[str]): The list of examples to use for entity extraction
- language (str): The language of the inputs and outputs
- encoding_model (str): The name of the model to use for token counting
- tokenizer (Tokenizer): The tokenizer to use for encoding and decoding text.
- max_token_count (int): The maximum number of tokens to use for the prompt
- json_mode (bool): Whether to use JSON mode for the prompt. Default is False
- output_path (Path | None): The path to write the prompt to. Default is None.
@ -56,10 +56,12 @@ def create_extract_graph_prompt(
if isinstance(entity_types, list):
entity_types = ", ".join(map(str, entity_types))
tokenizer = tokenizer or get_tokenizer()
tokens_left = (
max_token_count
- num_tokens_from_string(prompt, encoding_name=encoding_model)
- num_tokens_from_string(entity_types, encoding_name=encoding_model)
- tokenizer.num_tokens(prompt)
- tokenizer.num_tokens(entity_types)
if entity_types
else 0
)
@ -79,9 +81,7 @@ def create_extract_graph_prompt(
)
)
example_tokens = num_tokens_from_string(
example_formatted, encoding_name=encoding_model
)
example_tokens = tokenizer.num_tokens(example_formatted)
# Ensure at least three examples are included
if i >= min_examples_required and example_tokens > tokens_left:

View File

@ -8,11 +8,11 @@ import random
from typing import Any, cast
import pandas as pd
import tiktoken
from graphrag.data_model.community_report import CommunityReport
from graphrag.data_model.entity import Entity
from graphrag.query.llm.text_utils import num_tokens
from graphrag.tokenizer.get_tokenizer import get_tokenizer
from graphrag.tokenizer.tokenizer import Tokenizer
logger = logging.getLogger(__name__)
@ -24,7 +24,7 @@ NO_COMMUNITY_RECORDS_WARNING: str = (
def build_community_context(
community_reports: list[CommunityReport],
entities: list[Entity] | None = None,
token_encoder: tiktoken.Encoding | None = None,
tokenizer: Tokenizer | None = None,
use_community_summary: bool = True,
column_delimiter: str = "|",
shuffle_data: bool = True,
@ -46,6 +46,7 @@ def build_community_context(
The calculated weight is added as an attribute to the community reports and added to the context data table.
"""
tokenizer = tokenizer or get_tokenizer()
def _is_included(report: CommunityReport) -> bool:
return report.rank is not None and report.rank >= min_community_rank
@ -125,7 +126,7 @@ def build_community_context(
batch_text = (
f"-----{context_name}-----" + "\n" + column_delimiter.join(header) + "\n"
)
batch_tokens = num_tokens(batch_text, token_encoder)
batch_tokens = tokenizer.num_tokens(batch_text)
batch_records = []
def _cut_batch() -> None:
@ -152,7 +153,7 @@ def build_community_context(
for report in selected_reports:
new_context_text, new_context = _report_context_text(report, attributes)
new_tokens = num_tokens(new_context_text, token_encoder)
new_tokens = tokenizer.num_tokens(new_context_text)
if batch_tokens + new_tokens > max_context_tokens:
# add the current batch to the context data and start a new batch if we are in multi-batch mode

View File

@ -7,9 +7,9 @@ from dataclasses import dataclass
from enum import Enum
import pandas as pd
import tiktoken
from graphrag.query.llm.text_utils import num_tokens
from graphrag.tokenizer.get_tokenizer import get_tokenizer
from graphrag.tokenizer.tokenizer import Tokenizer
"""
Enum for conversation roles
@ -148,7 +148,7 @@ class ConversationHistory:
def build_context(
self,
token_encoder: tiktoken.Encoding | None = None,
tokenizer: Tokenizer | None = None,
include_user_turns_only: bool = True,
max_qa_turns: int | None = 5,
max_context_tokens: int = 8000,
@ -168,6 +168,7 @@ class ConversationHistory:
context_name: Name of the context, default is "Conversation History".
"""
tokenizer = tokenizer or get_tokenizer()
qa_turns = self.to_qa_turns()
if include_user_turns_only:
qa_turns = [
@ -202,7 +203,7 @@ class ConversationHistory:
context_df = pd.DataFrame(turn_list)
context_text = header + context_df.to_csv(sep=column_delimiter, index=False)
if num_tokens(context_text, token_encoder) > max_context_tokens:
if tokenizer.num_tokens(context_text) > max_context_tokens:
break
current_context_df = context_df

View File

@ -10,13 +10,12 @@ from copy import deepcopy
from time import time
from typing import Any
import tiktoken
from graphrag.data_model.community import Community
from graphrag.data_model.community_report import CommunityReport
from graphrag.language_model.protocol.base import ChatModel
from graphrag.query.context_builder.rate_prompt import RATE_QUERY
from graphrag.query.context_builder.rate_relevancy import rate_relevancy
from graphrag.tokenizer.tokenizer import Tokenizer
logger = logging.getLogger(__name__)
@ -32,7 +31,7 @@ class DynamicCommunitySelection:
community_reports: list[CommunityReport],
communities: list[Community],
model: ChatModel,
token_encoder: tiktoken.Encoding,
tokenizer: Tokenizer,
rate_query: str = RATE_QUERY,
use_summary: bool = False,
threshold: int = 1,
@ -43,7 +42,7 @@ class DynamicCommunitySelection:
model_params: dict[str, Any] | None = None,
):
self.model = model
self.token_encoder = token_encoder
self.tokenizer = tokenizer
self.rate_query = rate_query
self.num_repeats = num_repeats
self.use_summary = use_summary
@ -97,7 +96,7 @@ class DynamicCommunitySelection:
else self.reports[community].full_content
),
model=self.model,
token_encoder=self.token_encoder,
tokenizer=self.tokenizer,
rate_query=self.rate_query,
num_repeats=self.num_repeats,
semaphore=self.semaphore,

View File

@ -7,7 +7,6 @@ from collections import defaultdict
from typing import Any, cast
import pandas as pd
import tiktoken
from graphrag.data_model.covariate import Covariate
from graphrag.data_model.entity import Entity
@ -24,12 +23,13 @@ from graphrag.query.input.retrieval.relationships import (
get_out_network_relationships,
to_relationship_dataframe,
)
from graphrag.query.llm.text_utils import num_tokens
from graphrag.tokenizer.get_tokenizer import get_tokenizer
from graphrag.tokenizer.tokenizer import Tokenizer
def build_entity_context(
selected_entities: list[Entity],
token_encoder: tiktoken.Encoding | None = None,
tokenizer: Tokenizer | None = None,
max_context_tokens: int = 8000,
include_entity_rank: bool = True,
rank_description: str = "number of relationships",
@ -37,6 +37,8 @@ def build_entity_context(
context_name="Entities",
) -> tuple[str, pd.DataFrame]:
"""Prepare entity data table as context data for system prompt."""
tokenizer = tokenizer or get_tokenizer()
if len(selected_entities) == 0:
return "", pd.DataFrame()
@ -52,7 +54,7 @@ def build_entity_context(
)
header.extend(attribute_cols)
current_context_text += column_delimiter.join(header) + "\n"
current_tokens = num_tokens(current_context_text, token_encoder)
current_tokens = tokenizer.num_tokens(current_context_text)
all_context_records = [header]
for entity in selected_entities:
@ -71,7 +73,7 @@ def build_entity_context(
)
new_context.append(field_value)
new_context_text = column_delimiter.join(new_context) + "\n"
new_tokens = num_tokens(new_context_text, token_encoder)
new_tokens = tokenizer.num_tokens(new_context_text)
if current_tokens + new_tokens > max_context_tokens:
break
current_context_text += new_context_text
@ -91,12 +93,13 @@ def build_entity_context(
def build_covariates_context(
selected_entities: list[Entity],
covariates: list[Covariate],
token_encoder: tiktoken.Encoding | None = None,
tokenizer: Tokenizer | None = None,
max_context_tokens: int = 8000,
column_delimiter: str = "|",
context_name: str = "Covariates",
) -> tuple[str, pd.DataFrame]:
"""Prepare covariate data tables as context data for system prompt."""
tokenizer = tokenizer or get_tokenizer()
# create an empty list of covariates
if len(selected_entities) == 0 or len(covariates) == 0:
return "", pd.DataFrame()
@ -113,7 +116,7 @@ def build_covariates_context(
attribute_cols = list(attributes.keys()) if len(covariates) > 0 else []
header.extend(attribute_cols)
current_context_text += column_delimiter.join(header) + "\n"
current_tokens = num_tokens(current_context_text, token_encoder)
current_tokens = tokenizer.num_tokens(current_context_text)
all_context_records = [header]
for entity in selected_entities:
@ -135,7 +138,7 @@ def build_covariates_context(
new_context.append(field_value)
new_context_text = column_delimiter.join(new_context) + "\n"
new_tokens = num_tokens(new_context_text, token_encoder)
new_tokens = tokenizer.num_tokens(new_context_text)
if current_tokens + new_tokens > max_context_tokens:
break
current_context_text += new_context_text
@ -155,7 +158,7 @@ def build_covariates_context(
def build_relationship_context(
selected_entities: list[Entity],
relationships: list[Relationship],
token_encoder: tiktoken.Encoding | None = None,
tokenizer: Tokenizer | None = None,
include_relationship_weight: bool = False,
max_context_tokens: int = 8000,
top_k_relationships: int = 10,
@ -164,6 +167,7 @@ def build_relationship_context(
context_name: str = "Relationships",
) -> tuple[str, pd.DataFrame]:
"""Prepare relationship data tables as context data for system prompt."""
tokenizer = tokenizer or get_tokenizer()
selected_relationships = _filter_relationships(
selected_entities=selected_entities,
relationships=relationships,
@ -188,7 +192,7 @@ def build_relationship_context(
header.extend(attribute_cols)
current_context_text += column_delimiter.join(header) + "\n"
current_tokens = num_tokens(current_context_text, token_encoder)
current_tokens = tokenizer.num_tokens(current_context_text)
all_context_records = [header]
for rel in selected_relationships:
@ -208,7 +212,7 @@ def build_relationship_context(
)
new_context.append(field_value)
new_context_text = column_delimiter.join(new_context) + "\n"
new_tokens = num_tokens(new_context_text, token_encoder)
new_tokens = tokenizer.num_tokens(new_context_text)
if current_tokens + new_tokens > max_context_tokens:
break
current_context_text += new_context_text

View File

@ -9,11 +9,11 @@ from contextlib import nullcontext
from typing import Any
import numpy as np
import tiktoken
from graphrag.language_model.protocol.base import ChatModel
from graphrag.query.context_builder.rate_prompt import RATE_QUERY
from graphrag.query.llm.text_utils import num_tokens, try_parse_json_object
from graphrag.query.llm.text_utils import try_parse_json_object
from graphrag.tokenizer.tokenizer import Tokenizer
logger = logging.getLogger(__name__)
@ -22,7 +22,7 @@ async def rate_relevancy(
query: str,
description: str,
model: ChatModel,
token_encoder: tiktoken.Encoding,
tokenizer: Tokenizer,
rate_query: str = RATE_QUERY,
num_repeats: int = 1,
semaphore: asyncio.Semaphore | None = None,
@ -36,7 +36,7 @@ async def rate_relevancy(
description: the community description to rate, it can be the community
title, summary, or the full content.
llm: LLM model to use for rating
token_encoder: token encoder
tokenizer: the tokenizer used for token counting
num_repeats: number of times to repeat the rating process for the same community (default: 1)
model_params: additional arguments to pass to the LLM model
semaphore: asyncio.Semaphore to limit the number of concurrent LLM calls (default: None)
@ -63,8 +63,8 @@ async def rate_relevancy(
logger.warning("Error parsing json response, defaulting to rating 1")
ratings.append(1)
llm_calls += 1
prompt_tokens += num_tokens(messages[0]["content"], token_encoder)
output_tokens += num_tokens(response, token_encoder)
prompt_tokens += tokenizer.num_tokens(messages[0]["content"])
output_tokens += tokenizer.num_tokens(response)
# select the decision with the most votes
options, counts = np.unique(ratings, return_counts=True)
rating = int(options[np.argmax(counts)])
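
A minimal sketch of the token-counting pattern these hunks migrate to, assuming the Tokenizer interface exposes num_tokens and encode as used above:

# Sketch: count tokens with the Tokenizer interface instead of tiktoken directly.
from graphrag.tokenizer.get_tokenizer import get_tokenizer

tokenizer = get_tokenizer()  # optionally pass model_config=<model settings> to match a specific model
prompt_tokens = tokenizer.num_tokens("You are a helpful assistant.")
output_tokens = len(tokenizer.encode("A short response."))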

View File

@ -7,11 +7,11 @@ import random
from typing import Any, cast
import pandas as pd
import tiktoken
from graphrag.data_model.relationship import Relationship
from graphrag.data_model.text_unit import TextUnit
from graphrag.query.llm.text_utils import num_tokens
from graphrag.tokenizer.get_tokenizer import get_tokenizer
from graphrag.tokenizer.tokenizer import Tokenizer
"""
Contain util functions to build text unit context for the search's system prompt
@ -20,7 +20,7 @@ Contain util functions to build text unit context for the search's system prompt
def build_text_unit_context(
text_units: list[TextUnit],
token_encoder: tiktoken.Encoding | None = None,
tokenizer: Tokenizer | None = None,
column_delimiter: str = "|",
shuffle_data: bool = True,
max_context_tokens: int = 8000,
@ -28,6 +28,7 @@ def build_text_unit_context(
random_state: int = 86,
) -> tuple[str, dict[str, pd.DataFrame]]:
"""Prepare text-unit data table as context data for system prompt."""
tokenizer = tokenizer or get_tokenizer()
if text_units is None or len(text_units) == 0:
return ("", {})
@ -47,7 +48,7 @@ def build_text_unit_context(
header.extend(attribute_cols)
current_context_text += column_delimiter.join(header) + "\n"
current_tokens = num_tokens(current_context_text, token_encoder)
current_tokens = tokenizer.num_tokens(current_context_text)
all_context_records = [header]
for unit in text_units:
@ -60,7 +61,7 @@ def build_text_unit_context(
],
]
new_context_text = column_delimiter.join(new_context) + "\n"
new_tokens = num_tokens(new_context_text, token_encoder)
new_tokens = tokenizer.num_tokens(new_context_text)
if current_tokens + new_tokens > max_context_tokens:
break

View File

@ -3,8 +3,6 @@
"""Query Factory methods to support CLI."""
import tiktoken
from graphrag.callbacks.query_callbacks import QueryCallbacks
from graphrag.config.models.graph_rag_config import GraphRagConfig
from graphrag.data_model.community import Community
@ -34,6 +32,7 @@ from graphrag.query.structured_search.local_search.mixed_context import (
LocalSearchMixedContext,
)
from graphrag.query.structured_search.local_search.search import LocalSearch
from graphrag.tokenizer.get_tokenizer import get_tokenizer
from graphrag.vector_stores.base import BaseVectorStore
@ -68,7 +67,7 @@ def get_local_search_engine(
config=embedding_settings,
)
token_encoder = tiktoken.get_encoding(model_settings.encoding_model)
tokenizer = get_tokenizer(model_config=model_settings)
ls_config = config.local_search
@ -86,9 +85,9 @@ def get_local_search_engine(
entity_text_embeddings=description_embedding_store,
embedding_vectorstore_key=EntityVectorStoreKey.ID, # if the vectorstore uses entity title as ids, set this to EntityVectorStoreKey.TITLE
text_embedder=embedding_model,
token_encoder=token_encoder,
tokenizer=tokenizer,
),
token_encoder=token_encoder,
tokenizer=tokenizer,
model_params=model_params,
context_builder_params={
"text_unit_prop": ls_config.text_unit_prop,
@ -135,7 +134,7 @@ def get_global_search_engine(
model_params = get_openai_model_parameters_from_config(model_settings)
# Here we get encoding based on specified encoding name
token_encoder = tiktoken.get_encoding(model_settings.encoding_model)
tokenizer = get_tokenizer(model_config=model_settings)
gs_config = config.global_search
dynamic_community_selection_kwargs = {}
@ -144,7 +143,7 @@ def get_global_search_engine(
dynamic_community_selection_kwargs.update({
"model": model,
"token_encoder": token_encoder,
"tokenizer": tokenizer,
"keep_parent": gs_config.dynamic_search_keep_parent,
"num_repeats": gs_config.dynamic_search_num_repeats,
"use_summary": gs_config.dynamic_search_use_summary,
@ -163,11 +162,11 @@ def get_global_search_engine(
community_reports=reports,
communities=communities,
entities=entities,
token_encoder=token_encoder,
tokenizer=tokenizer,
dynamic_community_selection=dynamic_community_selection,
dynamic_community_selection_kwargs=dynamic_community_selection_kwargs,
),
token_encoder=token_encoder,
tokenizer=tokenizer,
max_data_tokens=gs_config.data_max_tokens,
map_llm_params={**model_params},
reduce_llm_params={**model_params},
@ -226,7 +225,7 @@ def get_drift_search_engine(
config=embedding_model_settings,
)
token_encoder = tiktoken.get_encoding(chat_model_settings.encoding_model)
tokenizer = get_tokenizer(model_config=chat_model_settings)
return DRIFTSearch(
model=chat_model,
@ -243,7 +242,7 @@ def get_drift_search_engine(
config=config.drift_search,
response_type=response_type,
),
token_encoder=token_encoder,
tokenizer=tokenizer,
callbacks=callbacks,
)
@ -277,7 +276,7 @@ def get_basic_search_engine(
config=embedding_model_settings,
)
token_encoder = tiktoken.get_encoding(chat_model_settings.encoding_model)
tokenizer = get_tokenizer(model_config=chat_model_settings)
bs_config = config.basic_search
@ -291,9 +290,9 @@ def get_basic_search_engine(
text_embedder=embedding_model,
text_unit_embeddings=text_unit_embeddings,
text_units=text_units,
token_encoder=token_encoder,
tokenizer=tokenizer,
),
token_encoder=token_encoder,
tokenizer=tokenizer,
model_params=model_params,
context_builder_params={
"embedding_vectorstore_key": "id",

View File

@ -9,7 +9,6 @@ from dataclasses import dataclass
from typing import Any, Generic, TypeVar
import pandas as pd
import tiktoken
from graphrag.language_model.protocol.base import ChatModel
from graphrag.query.context_builder.builders import (
@ -21,6 +20,8 @@ from graphrag.query.context_builder.builders import (
from graphrag.query.context_builder.conversation_history import (
ConversationHistory,
)
from graphrag.tokenizer.get_tokenizer import get_tokenizer
from graphrag.tokenizer.tokenizer import Tokenizer
@dataclass
@ -58,13 +59,13 @@ class BaseSearch(ABC, Generic[T]):
self,
model: ChatModel,
context_builder: T,
token_encoder: tiktoken.Encoding | None = None,
tokenizer: Tokenizer | None = None,
model_params: dict[str, Any] | None = None,
context_builder_params: dict[str, Any] | None = None,
):
self.model = model
self.context_builder = context_builder
self.token_encoder = token_encoder
self.tokenizer = tokenizer or get_tokenizer()
self.model_params = model_params or {}
self.context_builder_params = context_builder_params or {}

View File

@ -7,7 +7,6 @@ import logging
from typing import cast
import pandas as pd
import tiktoken
from graphrag.data_model.text_unit import TextUnit
from graphrag.language_model.protocol.base import EmbeddingModel
@ -16,7 +15,8 @@ from graphrag.query.context_builder.builders import (
ContextBuilderResult,
)
from graphrag.query.context_builder.conversation_history import ConversationHistory
from graphrag.query.llm.text_utils import num_tokens
from graphrag.tokenizer.get_tokenizer import get_tokenizer
from graphrag.tokenizer.tokenizer import Tokenizer
from graphrag.vector_stores.base import BaseVectorStore
logger = logging.getLogger(__name__)
@ -30,11 +30,11 @@ class BasicSearchContext(BasicContextBuilder):
text_embedder: EmbeddingModel,
text_unit_embeddings: BaseVectorStore,
text_units: list[TextUnit] | None = None,
token_encoder: tiktoken.Encoding | None = None,
tokenizer: Tokenizer | None = None,
embedding_vectorstore_key: str = "id",
):
self.text_embedder = text_embedder
self.token_encoder = token_encoder
self.tokenizer = tokenizer or get_tokenizer()
self.text_units = text_units
self.text_unit_embeddings = text_unit_embeddings
self.embedding_vectorstore_key = embedding_vectorstore_key
@ -76,12 +76,12 @@ class BasicSearchContext(BasicContextBuilder):
# add these related text chunks into context until we fill up the context window
current_tokens = 0
text_ids = []
current_tokens = num_tokens(
text_id_col + column_delimiter + text_col + "\n", self.token_encoder
current_tokens = len(
self.tokenizer.encode(text_id_col + column_delimiter + text_col + "\n")
)
for i, row in related_text_df.iterrows():
text = row[text_id_col] + column_delimiter + row[text_col] + "\n"
tokens = num_tokens(text, self.token_encoder)
tokens = len(self.tokenizer.encode(text))
if current_tokens + tokens > max_context_tokens:
msg = f"Reached token limit: {current_tokens + tokens}. Reverting to previous context state"
logger.warning(msg)

View File

@ -8,8 +8,6 @@ import time
from collections.abc import AsyncGenerator
from typing import Any
import tiktoken
from graphrag.callbacks.query_callbacks import QueryCallbacks
from graphrag.language_model.protocol.base import ChatModel
from graphrag.prompts.query.basic_search_system_prompt import (
@ -17,8 +15,8 @@ from graphrag.prompts.query.basic_search_system_prompt import (
)
from graphrag.query.context_builder.builders import BasicContextBuilder
from graphrag.query.context_builder.conversation_history import ConversationHistory
from graphrag.query.llm.text_utils import num_tokens
from graphrag.query.structured_search.base import BaseSearch, SearchResult
from graphrag.tokenizer.tokenizer import Tokenizer
logger = logging.getLogger(__name__)
"""
@ -33,7 +31,7 @@ class BasicSearch(BaseSearch[BasicContextBuilder]):
self,
model: ChatModel,
context_builder: BasicContextBuilder,
token_encoder: tiktoken.Encoding | None = None,
tokenizer: Tokenizer | None = None,
system_prompt: str | None = None,
response_type: str = "multiple paragraphs",
callbacks: list[QueryCallbacks] | None = None,
@ -43,7 +41,7 @@ class BasicSearch(BaseSearch[BasicContextBuilder]):
super().__init__(
model=model,
context_builder=context_builder,
token_encoder=token_encoder,
tokenizer=tokenizer,
model_params=model_params,
context_builder_params=context_builder_params or {},
)
@ -94,8 +92,8 @@ class BasicSearch(BaseSearch[BasicContextBuilder]):
response += chunk
llm_calls["response"] = 1
prompt_tokens["response"] = num_tokens(search_prompt, self.token_encoder)
output_tokens["response"] = num_tokens(response, self.token_encoder)
prompt_tokens["response"] = len(self.tokenizer.encode(search_prompt))
output_tokens["response"] = len(self.tokenizer.encode(response))
for callback in self.callbacks:
callback.on_context(context_result.context_records)
@ -106,7 +104,7 @@ class BasicSearch(BaseSearch[BasicContextBuilder]):
context_text=context_result.context_chunks,
completion_time=time.time() - start_time,
llm_calls=1,
prompt_tokens=num_tokens(search_prompt, self.token_encoder),
prompt_tokens=len(self.tokenizer.encode(search_prompt)),
output_tokens=sum(output_tokens.values()),
llm_calls_categories=llm_calls,
prompt_tokens_categories=prompt_tokens,
@ -121,7 +119,7 @@ class BasicSearch(BaseSearch[BasicContextBuilder]):
context_text=context_result.context_chunks,
completion_time=time.time() - start_time,
llm_calls=1,
prompt_tokens=num_tokens(search_prompt, self.token_encoder),
prompt_tokens=len(self.tokenizer.encode(search_prompt)),
output_tokens=0,
llm_calls_categories=llm_calls,
prompt_tokens_categories=prompt_tokens,

View File

@ -9,7 +9,6 @@ from typing import Any
import numpy as np
import pandas as pd
import tiktoken
from graphrag.config.models.drift_search_config import DRIFTSearchConfig
from graphrag.data_model.community_report import CommunityReport
@ -28,6 +27,8 @@ from graphrag.query.structured_search.drift_search.primer import PrimerQueryProc
from graphrag.query.structured_search.local_search.mixed_context import (
LocalSearchMixedContext,
)
from graphrag.tokenizer.get_tokenizer import get_tokenizer
from graphrag.tokenizer.tokenizer import Tokenizer
from graphrag.vector_stores.base import BaseVectorStore
logger = logging.getLogger(__name__)
@ -46,7 +47,7 @@ class DRIFTSearchContextBuilder(DRIFTContextBuilder):
reports: list[CommunityReport] | None = None,
relationships: list[Relationship] | None = None,
covariates: dict[str, list[Covariate]] | None = None,
token_encoder: tiktoken.Encoding | None = None,
tokenizer: Tokenizer | None = None,
embedding_vectorstore_key: str = EntityVectorStoreKey.ID,
config: DRIFTSearchConfig | None = None,
local_system_prompt: str | None = None,
@ -58,7 +59,7 @@ class DRIFTSearchContextBuilder(DRIFTContextBuilder):
self.config = config or DRIFTSearchConfig()
self.model = model
self.text_embedder = text_embedder
self.token_encoder = token_encoder
self.tokenizer = tokenizer or get_tokenizer()
self.local_system_prompt = local_system_prompt or DRIFT_LOCAL_SYSTEM_PROMPT
self.reduce_system_prompt = reduce_system_prompt or DRIFT_REDUCE_PROMPT
@ -93,7 +94,7 @@ class DRIFTSearchContextBuilder(DRIFTContextBuilder):
entity_text_embeddings=self.entity_text_embeddings,
embedding_vectorstore_key=self.embedding_vectorstore_key,
text_embedder=self.text_embedder,
token_encoder=self.token_encoder,
tokenizer=self.tokenizer,
)
@staticmethod
@ -192,7 +193,7 @@ class DRIFTSearchContextBuilder(DRIFTContextBuilder):
query_processor = PrimerQueryProcessor(
chat_model=self.model,
text_embedder=self.text_embedder,
token_encoder=self.token_encoder,
tokenizer=self.tokenizer,
reports=self.reports,
)

View File

@ -10,7 +10,6 @@ import time
import numpy as np
import pandas as pd
import tiktoken
from tqdm.asyncio import tqdm_asyncio
from graphrag.config.models.drift_search_config import DRIFTSearchConfig
@ -19,8 +18,9 @@ from graphrag.language_model.protocol.base import ChatModel, EmbeddingModel
from graphrag.prompts.query.drift_search_system_prompt import (
DRIFT_PRIMER_PROMPT,
)
from graphrag.query.llm.text_utils import num_tokens
from graphrag.query.structured_search.base import SearchResult
from graphrag.tokenizer.get_tokenizer import get_tokenizer
from graphrag.tokenizer.tokenizer import Tokenizer
logger = logging.getLogger(__name__)
@ -33,7 +33,7 @@ class PrimerQueryProcessor:
chat_model: ChatModel,
text_embedder: EmbeddingModel,
reports: list[CommunityReport],
token_encoder: tiktoken.Encoding | None = None,
tokenizer: Tokenizer | None = None,
):
"""
Initialize the PrimerQueryProcessor.
@ -42,11 +42,11 @@ class PrimerQueryProcessor:
chat_llm (ChatOpenAI): The language model used to process the query.
text_embedder (BaseTextEmbedding): The text embedding model.
reports (list[CommunityReport]): List of community reports.
token_encoder (tiktoken.Encoding, optional): Token encoder for token counting.
tokenizer (Tokenizer, optional): Tokenizer used for token counting.
"""
self.chat_model = chat_model
self.text_embedder = text_embedder
self.token_encoder = token_encoder
self.tokenizer = tokenizer or get_tokenizer()
self.reports = reports
async def expand_query(self, query: str) -> tuple[str, dict[str, int]]:
@ -70,8 +70,8 @@ class PrimerQueryProcessor:
model_response = await self.chat_model.achat(prompt)
text = model_response.output.content
prompt_tokens = num_tokens(prompt, self.token_encoder)
output_tokens = num_tokens(text, self.token_encoder)
prompt_tokens = len(self.tokenizer.encode(prompt))
output_tokens = len(self.tokenizer.encode(text))
token_ct = {
"llm_calls": 1,
"prompt_tokens": prompt_tokens,
@ -105,7 +105,7 @@ class DRIFTPrimer:
self,
config: DRIFTSearchConfig,
chat_model: ChatModel,
token_encoder: tiktoken.Encoding | None = None,
tokenizer: Tokenizer | None = None,
):
"""
Initialize the DRIFTPrimer.
@ -113,11 +113,11 @@ class DRIFTPrimer:
Args:
config (DRIFTSearchConfig): Configuration settings for DRIFT search.
chat_llm (ChatOpenAI): The language model used for searching.
token_encoder (tiktoken.Encoding, optional): Token encoder for managing tokens.
tokenizer (Tokenizer, optional): Tokenizer for managing tokens.
"""
self.chat_model = chat_model
self.config = config
self.token_encoder = token_encoder
self.tokenizer = tokenizer or get_tokenizer()
async def decompose_query(
self, query: str, reports: pd.DataFrame
@ -144,8 +144,8 @@ class DRIFTPrimer:
token_ct = {
"llm_calls": 1,
"prompt_tokens": num_tokens(prompt, self.token_encoder),
"output_tokens": num_tokens(response, self.token_encoder),
"prompt_tokens": len(self.tokenizer.encode(prompt)),
"output_tokens": len(self.tokenizer.encode(response)),
}
return parsed_response, token_ct

View File

@ -8,7 +8,6 @@ import time
from collections.abc import AsyncGenerator
from typing import Any
import tiktoken
from tqdm.asyncio import tqdm_asyncio
from graphrag.callbacks.query_callbacks import QueryCallbacks
@ -18,7 +17,6 @@ from graphrag.language_model.providers.fnllm.utils import (
)
from graphrag.query.context_builder.conversation_history import ConversationHistory
from graphrag.query.context_builder.entity_extraction import EntityVectorStoreKey
from graphrag.query.llm.text_utils import num_tokens
from graphrag.query.structured_search.base import BaseSearch, SearchResult
from graphrag.query.structured_search.drift_search.action import DriftAction
from graphrag.query.structured_search.drift_search.drift_context import (
@ -27,6 +25,8 @@ from graphrag.query.structured_search.drift_search.drift_context import (
from graphrag.query.structured_search.drift_search.primer import DRIFTPrimer
from graphrag.query.structured_search.drift_search.state import QueryState
from graphrag.query.structured_search.local_search.search import LocalSearch
from graphrag.tokenizer.get_tokenizer import get_tokenizer
from graphrag.tokenizer.tokenizer import Tokenizer
logger = logging.getLogger(__name__)
@ -38,7 +38,7 @@ class DRIFTSearch(BaseSearch[DRIFTSearchContextBuilder]):
self,
model: ChatModel,
context_builder: DRIFTSearchContextBuilder,
token_encoder: tiktoken.Encoding | None = None,
tokenizer: Tokenizer | None = None,
query_state: QueryState | None = None,
callbacks: list[QueryCallbacks] | None = None,
):
@ -49,18 +49,18 @@ class DRIFTSearch(BaseSearch[DRIFTSearchContextBuilder]):
llm (ChatOpenAI): The language model used for searching.
context_builder (DRIFTSearchContextBuilder): Builder for search context.
config (DRIFTSearchConfig, optional): Configuration settings for DRIFTSearch.
token_encoder (tiktoken.Encoding, optional): Token encoder for managing tokens.
tokenizer (Tokenizer, optional): Tokenizer for managing tokens.
query_state (QueryState, optional): State of the current search query.
"""
super().__init__(model, context_builder, token_encoder)
super().__init__(model, context_builder, tokenizer)
self.context_builder = context_builder
self.token_encoder = token_encoder
self.tokenizer = tokenizer or get_tokenizer()
self.query_state = query_state or QueryState()
self.primer = DRIFTPrimer(
config=self.context_builder.config,
chat_model=model,
token_encoder=token_encoder,
tokenizer=self.tokenizer,
)
self.callbacks = callbacks or []
self.local_search = self.init_local_search()
@ -100,7 +100,7 @@ class DRIFTSearch(BaseSearch[DRIFTSearchContextBuilder]):
model=self.model,
system_prompt=self.context_builder.local_system_prompt,
context_builder=self.context_builder.local_mixed_context,
token_encoder=self.token_encoder,
tokenizer=self.tokenizer,
model_params=model_params,
context_builder_params=local_context_params,
response_type="multiple paragraphs",
@ -392,10 +392,10 @@ class DRIFTSearch(BaseSearch[DRIFTSearchContextBuilder]):
reduced_response = model_response.output.content
llm_calls["reduce"] = 1
prompt_tokens["reduce"] = num_tokens(
search_prompt, self.token_encoder
) + num_tokens(query, self.token_encoder)
output_tokens["reduce"] = num_tokens(reduced_response, self.token_encoder)
prompt_tokens["reduce"] = len(self.tokenizer.encode(search_prompt)) + len(
self.tokenizer.encode(query)
)
output_tokens["reduce"] = len(self.tokenizer.encode(reduced_response))
return reduced_response

View File

@ -5,8 +5,6 @@
from typing import Any
import tiktoken
from graphrag.data_model.community import Community
from graphrag.data_model.community_report import CommunityReport
from graphrag.data_model.entity import Entity
@ -21,6 +19,8 @@ from graphrag.query.context_builder.dynamic_community_selection import (
DynamicCommunitySelection,
)
from graphrag.query.structured_search.base import GlobalContextBuilder
from graphrag.tokenizer.get_tokenizer import get_tokenizer
from graphrag.tokenizer.tokenizer import Tokenizer
class GlobalCommunityContext(GlobalContextBuilder):
@ -31,14 +31,14 @@ class GlobalCommunityContext(GlobalContextBuilder):
community_reports: list[CommunityReport],
communities: list[Community],
entities: list[Entity] | None = None,
token_encoder: tiktoken.Encoding | None = None,
tokenizer: Tokenizer | None = None,
dynamic_community_selection: bool = False,
dynamic_community_selection_kwargs: dict[str, Any] | None = None,
random_state: int = 86,
):
self.community_reports = community_reports
self.entities = entities
self.token_encoder = token_encoder
self.tokenizer = tokenizer or get_tokenizer()
self.dynamic_community_selection = None
if dynamic_community_selection and isinstance(
dynamic_community_selection_kwargs, dict
@ -47,7 +47,7 @@ class GlobalCommunityContext(GlobalContextBuilder):
community_reports=community_reports,
communities=communities,
model=dynamic_community_selection_kwargs.pop("model"),
token_encoder=dynamic_community_selection_kwargs.pop("token_encoder"),
tokenizer=dynamic_community_selection_kwargs.pop("tokenizer"),
**dynamic_community_selection_kwargs,
)
self.random_state = random_state
@ -103,7 +103,7 @@ class GlobalCommunityContext(GlobalContextBuilder):
community_context, community_context_data = build_community_context(
community_reports=community_reports,
entities=self.entities,
token_encoder=self.token_encoder,
tokenizer=self.tokenizer,
use_community_summary=use_community_summary,
column_delimiter=column_delimiter,
shuffle_data=shuffle_data,

View File

@ -12,7 +12,6 @@ from dataclasses import dataclass
from typing import Any
import pandas as pd
import tiktoken
from graphrag.callbacks.query_callbacks import QueryCallbacks
from graphrag.language_model.protocol.base import ChatModel
@ -30,8 +29,9 @@ from graphrag.query.context_builder.builders import GlobalContextBuilder
from graphrag.query.context_builder.conversation_history import (
ConversationHistory,
)
from graphrag.query.llm.text_utils import num_tokens, try_parse_json_object
from graphrag.query.llm.text_utils import try_parse_json_object
from graphrag.query.structured_search.base import BaseSearch, SearchResult
from graphrag.tokenizer.tokenizer import Tokenizer
logger = logging.getLogger(__name__)
@ -52,7 +52,7 @@ class GlobalSearch(BaseSearch[GlobalContextBuilder]):
self,
model: ChatModel,
context_builder: GlobalContextBuilder,
token_encoder: tiktoken.Encoding | None = None,
tokenizer: Tokenizer | None = None,
map_system_prompt: str | None = None,
reduce_system_prompt: str | None = None,
response_type: str = "multiple paragraphs",
@ -71,7 +71,7 @@ class GlobalSearch(BaseSearch[GlobalContextBuilder]):
super().__init__(
model=model,
context_builder=context_builder,
token_encoder=token_encoder,
tokenizer=tokenizer,
context_builder_params=context_builder_params,
)
self.map_system_prompt = map_system_prompt or MAP_SYSTEM_PROMPT
@ -247,8 +247,8 @@ class GlobalSearch(BaseSearch[GlobalContextBuilder]):
context_text=context_data,
completion_time=time.time() - start_time,
llm_calls=1,
prompt_tokens=num_tokens(search_prompt, self.token_encoder),
output_tokens=num_tokens(search_response, self.token_encoder),
prompt_tokens=len(self.tokenizer.encode(search_prompt)),
output_tokens=len(self.tokenizer.encode(search_response)),
)
except Exception:
@ -259,7 +259,7 @@ class GlobalSearch(BaseSearch[GlobalContextBuilder]):
context_text=context_data,
completion_time=time.time() - start_time,
llm_calls=1,
prompt_tokens=num_tokens(search_prompt, self.token_encoder),
prompt_tokens=len(self.tokenizer.encode(search_prompt)),
output_tokens=0,
)
@ -361,13 +361,12 @@ class GlobalSearch(BaseSearch[GlobalContextBuilder]):
formatted_response_data.append(point["answer"]) # type: ignore
formatted_response_text = "\n".join(formatted_response_data)
if (
total_tokens
+ num_tokens(formatted_response_text, self.token_encoder)
total_tokens + len(self.tokenizer.encode(formatted_response_text))
> self.max_data_tokens
):
break
data.append(formatted_response_text)
total_tokens += num_tokens(formatted_response_text, self.token_encoder)
total_tokens += len(self.tokenizer.encode(formatted_response_text))
text_data = "\n\n".join(data)
search_prompt = self.reduce_system_prompt.format(
@ -398,8 +397,8 @@ class GlobalSearch(BaseSearch[GlobalContextBuilder]):
context_text=text_data,
completion_time=time.time() - start_time,
llm_calls=1,
prompt_tokens=num_tokens(search_prompt, self.token_encoder),
output_tokens=num_tokens(search_response, self.token_encoder),
prompt_tokens=len(self.tokenizer.encode(search_prompt)),
output_tokens=len(self.tokenizer.encode(search_response)),
)
except Exception:
logger.exception("Exception in reduce_response")
@ -409,7 +408,7 @@ class GlobalSearch(BaseSearch[GlobalContextBuilder]):
context_text=text_data,
completion_time=time.time() - start_time,
llm_calls=1,
prompt_tokens=num_tokens(search_prompt, self.token_encoder),
prompt_tokens=len(self.tokenizer.encode(search_prompt)),
output_tokens=0,
)
@ -467,12 +466,12 @@ class GlobalSearch(BaseSearch[GlobalContextBuilder]):
]
formatted_response_text = "\n".join(formatted_response_data)
if (
total_tokens + num_tokens(formatted_response_text, self.token_encoder)
total_tokens + len(self.tokenizer.encode(formatted_response_text))
> self.max_data_tokens
):
break
data.append(formatted_response_text)
total_tokens += num_tokens(formatted_response_text, self.token_encoder)
total_tokens += len(self.tokenizer.encode(formatted_response_text))
text_data = "\n\n".join(data)
search_prompt = self.reduce_system_prompt.format(

View File

@ -7,7 +7,6 @@ from copy import deepcopy
from typing import Any
import pandas as pd
import tiktoken
from graphrag.data_model.community_report import CommunityReport
from graphrag.data_model.covariate import Covariate
@ -40,8 +39,9 @@ from graphrag.query.input.retrieval.community_reports import (
get_candidate_communities,
)
from graphrag.query.input.retrieval.text_units import get_candidate_text_units
from graphrag.query.llm.text_utils import num_tokens
from graphrag.query.structured_search.base import LocalContextBuilder
from graphrag.tokenizer.get_tokenizer import get_tokenizer
from graphrag.tokenizer.tokenizer import Tokenizer
from graphrag.vector_stores.base import BaseVectorStore
logger = logging.getLogger(__name__)
@ -59,7 +59,7 @@ class LocalSearchMixedContext(LocalContextBuilder):
community_reports: list[CommunityReport] | None = None,
relationships: list[Relationship] | None = None,
covariates: dict[str, list[Covariate]] | None = None,
token_encoder: tiktoken.Encoding | None = None,
tokenizer: Tokenizer | None = None,
embedding_vectorstore_key: str = EntityVectorStoreKey.ID,
):
if community_reports is None:
@ -81,7 +81,7 @@ class LocalSearchMixedContext(LocalContextBuilder):
self.covariates = covariates
self.entity_text_embeddings = entity_text_embeddings
self.text_embedder = text_embedder
self.token_encoder = token_encoder
self.tokenizer = tokenizer or get_tokenizer()
self.embedding_vectorstore_key = embedding_vectorstore_key
def filter_by_entity_keys(self, entity_keys: list[int] | list[str]):
@ -167,8 +167,8 @@ class LocalSearchMixedContext(LocalContextBuilder):
if conversation_history_context.strip() != "":
final_context.append(conversation_history_context)
final_context_data = conversation_history_context_data
max_context_tokens = max_context_tokens - num_tokens(
conversation_history_context, self.token_encoder
max_context_tokens = max_context_tokens - len(
self.tokenizer.encode(conversation_history_context)
)
# build community context
@ -264,7 +264,7 @@ class LocalSearchMixedContext(LocalContextBuilder):
context_text, context_data = build_community_context(
community_reports=selected_communities,
token_encoder=self.token_encoder,
tokenizer=self.tokenizer,
use_community_summary=use_community_summary,
column_delimiter=column_delimiter,
shuffle_data=False,
@ -344,7 +344,7 @@ class LocalSearchMixedContext(LocalContextBuilder):
context_text, context_data = build_text_unit_context(
text_units=selected_text_units,
token_encoder=self.token_encoder,
tokenizer=self.tokenizer,
max_context_tokens=max_context_tokens,
shuffle_data=False,
context_name=context_name,
@ -390,14 +390,14 @@ class LocalSearchMixedContext(LocalContextBuilder):
# build entity context
entity_context, entity_context_data = build_entity_context(
selected_entities=selected_entities,
token_encoder=self.token_encoder,
tokenizer=self.tokenizer,
max_context_tokens=max_context_tokens,
column_delimiter=column_delimiter,
include_entity_rank=include_entity_rank,
rank_description=rank_description,
context_name="Entities",
)
entity_tokens = num_tokens(entity_context, self.token_encoder)
entity_tokens = len(self.tokenizer.encode(entity_context))
# build relationship-covariate context
added_entities = []
@ -417,7 +417,7 @@ class LocalSearchMixedContext(LocalContextBuilder):
) = build_relationship_context(
selected_entities=added_entities,
relationships=list(self.relationships.values()),
token_encoder=self.token_encoder,
tokenizer=self.tokenizer,
max_context_tokens=max_context_tokens,
column_delimiter=column_delimiter,
top_k_relationships=top_k_relationships,
@ -427,8 +427,8 @@ class LocalSearchMixedContext(LocalContextBuilder):
)
current_context.append(relationship_context)
current_context_data["relationships"] = relationship_context_data
total_tokens = entity_tokens + num_tokens(
relationship_context, self.token_encoder
total_tokens = entity_tokens + len(
self.tokenizer.encode(relationship_context)
)
# build covariate context
@ -436,12 +436,12 @@ class LocalSearchMixedContext(LocalContextBuilder):
covariate_context, covariate_context_data = build_covariates_context(
selected_entities=added_entities,
covariates=self.covariates[covariate],
token_encoder=self.token_encoder,
tokenizer=self.tokenizer,
max_context_tokens=max_context_tokens,
column_delimiter=column_delimiter,
context_name=covariate,
)
total_tokens += num_tokens(covariate_context, self.token_encoder)
total_tokens += len(self.tokenizer.encode(covariate_context))
current_context.append(covariate_context)
current_context_data[covariate.lower()] = covariate_context_data

View File

@ -8,8 +8,6 @@ import time
from collections.abc import AsyncGenerator
from typing import Any
import tiktoken
from graphrag.callbacks.query_callbacks import QueryCallbacks
from graphrag.language_model.protocol.base import ChatModel
from graphrag.prompts.query.local_search_system_prompt import (
@ -19,8 +17,8 @@ from graphrag.query.context_builder.builders import LocalContextBuilder
from graphrag.query.context_builder.conversation_history import (
ConversationHistory,
)
from graphrag.query.llm.text_utils import num_tokens
from graphrag.query.structured_search.base import BaseSearch, SearchResult
from graphrag.tokenizer.tokenizer import Tokenizer
logger = logging.getLogger(__name__)
@ -32,7 +30,7 @@ class LocalSearch(BaseSearch[LocalContextBuilder]):
self,
model: ChatModel,
context_builder: LocalContextBuilder,
token_encoder: tiktoken.Encoding | None = None,
tokenizer: Tokenizer | None = None,
system_prompt: str | None = None,
response_type: str = "multiple paragraphs",
callbacks: list[QueryCallbacks] | None = None,
@ -42,7 +40,7 @@ class LocalSearch(BaseSearch[LocalContextBuilder]):
super().__init__(
model=model,
context_builder=context_builder,
token_encoder=token_encoder,
tokenizer=tokenizer,
model_params=model_params,
context_builder_params=context_builder_params or {},
)
@ -100,8 +98,8 @@ class LocalSearch(BaseSearch[LocalContextBuilder]):
callback.on_llm_new_token(response)
llm_calls["response"] = 1
prompt_tokens["response"] = num_tokens(search_prompt, self.token_encoder)
output_tokens["response"] = num_tokens(full_response, self.token_encoder)
prompt_tokens["response"] = len(self.tokenizer.encode(search_prompt))
output_tokens["response"] = len(self.tokenizer.encode(full_response))
for callback in self.callbacks:
callback.on_context(context_result.context_records)
@ -127,7 +125,7 @@ class LocalSearch(BaseSearch[LocalContextBuilder]):
context_text=context_result.context_chunks,
completion_time=time.time() - start_time,
llm_calls=1,
prompt_tokens=num_tokens(search_prompt, self.token_encoder),
prompt_tokens=len(self.tokenizer.encode(search_prompt)),
output_tokens=0,
)
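Aside (not part of the diff): this change also adds a num_tokens helper to the Tokenizer interface (see the abstract base class later in this commit), so the repeated len(self.tokenizer.encode(...)) pattern above could equivalently be written with it:

prompt_tokens = len(self.tokenizer.encode(search_prompt))   # pattern used throughout this diff
prompt_tokens = self.tokenizer.num_tokens(search_prompt)    # equivalent; num_tokens is defined as len(self.encode(text))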

View File

@ -0,0 +1,4 @@
# Copyright (c) 2024 Microsoft Corporation.
# Licensed under the MIT License
"""GraphRAG tokenizer."""

View File

@ -0,0 +1,45 @@
# Copyright (c) 2024 Microsoft Corporation.
# Licensed under the MIT License
"""Get Tokenizer."""
from typing import TYPE_CHECKING
from graphrag.config.defaults import ENCODING_MODEL
from graphrag.tokenizer.litellm_tokenizer import LitellmTokenizer
from graphrag.tokenizer.tiktoken_tokenizer import TiktokenTokenizer
from graphrag.tokenizer.tokenizer import Tokenizer
if TYPE_CHECKING:
from graphrag.config.models.language_model_config import LanguageModelConfig
def get_tokenizer(
model_config: "LanguageModelConfig | None" = None,
encoding_model: str = ENCODING_MODEL,
) -> Tokenizer:
"""
Get the tokenizer for the given model configuration, or fall back to a tiktoken-based tokenizer.
Args
----
model_config: LanguageModelConfig, optional
The model configuration. If it is omitted, or if model_config.encoding_model is manually set,
a tiktoken-based tokenizer is used. Otherwise, a LitellmTokenizer is created for the configured
model name; LiteLLM provides token encoding/decoding for the models it supports.
encoding_model: str, optional
The tiktoken encoding name to use when no model configuration is provided.
Returns
-------
An instance of a Tokenizer.
"""
if model_config is not None:
if model_config.encoding_model.strip() != "":
# User has manually specified a tiktoken encoding model to use for the provided model configuration.
return TiktokenTokenizer(encoding_name=model_config.encoding_model)
return LitellmTokenizer(model_name=model_config.model)
return TiktokenTokenizer(encoding_name=encoding_model)
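For reference, a minimal usage sketch of the selection logic above (not part of the diff; model_config is assumed to be a LanguageModelConfig loaded elsewhere, for example from the indexing settings):

from graphrag.tokenizer.get_tokenizer import get_tokenizer

# No configuration: falls back to a tiktoken encoding (ENCODING_MODEL by default).
default_tokenizer = get_tokenizer()
token_ids = default_tokenizer.encode("hello world")
text = default_tokenizer.decode(token_ids)

# With a configuration: uses a LiteLLM tokenizer for model_config.model unless
# model_config.encoding_model is explicitly set to a tiktoken encoding name.
model_tokenizer = get_tokenizer(model_config=model_config)
count = model_tokenizer.num_tokens("hello world")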

View File

@ -0,0 +1,47 @@
# Copyright (c) 2024 Microsoft Corporation.
# Licensed under the MIT License
"""LiteLLM Tokenizer."""
from litellm import decode, encode # type: ignore
from graphrag.tokenizer.tokenizer import Tokenizer
class LitellmTokenizer(Tokenizer):
"""LiteLLM Tokenizer."""
def __init__(self, model_name: str) -> None:
"""Initialize the LiteLLM Tokenizer.
Args
----
model_name (str): The name of the LiteLLM model to use for tokenization.
"""
self.model_name = model_name
def encode(self, text: str) -> list[int]:
"""Encode the given text into a list of tokens.
Args
----
text (str): The input text to encode.
Returns
-------
list[int]: A list of tokens representing the encoded text.
"""
return encode(model=self.model_name, text=text)
def decode(self, tokens: list[int]) -> str:
"""Decode a list of tokens back into a string.
Args
----
tokens (list[int]): A list of tokens to decode.
Returns
-------
str: The decoded string from the list of tokens.
"""
return decode(model=self.model_name, tokens=tokens)
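A short usage sketch of this wrapper on its own; the model name is illustrative, and any model LiteLLM can tokenize would work:

from graphrag.tokenizer.litellm_tokenizer import LitellmTokenizer

tokenizer = LitellmTokenizer(model_name="gpt-4o")  # illustrative model name
token_ids = tokenizer.encode("GraphRAG tokenization test")
round_trip = tokenizer.decode(token_ids)
count = tokenizer.num_tokens("GraphRAG tokenization test")  # inherited from the Tokenizer base class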

View File

@ -0,0 +1,47 @@
# Copyright (c) 2024 Microsoft Corporation.
# Licensed under the MIT License
"""Tiktoken Tokenizer."""
import tiktoken
from graphrag.tokenizer.tokenizer import Tokenizer
class TiktokenTokenizer(Tokenizer):
"""Tiktoken Tokenizer."""
def __init__(self, encoding_name: str) -> None:
"""Initialize the Tiktoken Tokenizer.
Args
----
encoding_name (str): The name of the Tiktoken encoding to use for tokenization.
"""
self.encoding = tiktoken.get_encoding(encoding_name)
def encode(self, text: str) -> list[int]:
"""Encode the given text into a list of tokens.
Args
----
text (str): The input text to encode.
Returns
-------
list[int]: A list of tokens representing the encoded text.
"""
return self.encoding.encode(text)
def decode(self, tokens: list[int]) -> str:
"""Decode a list of tokens back into a string.
Args
----
tokens (list[int]): A list of tokens to decode.
Returns
-------
str: The decoded string from the list of tokens.
"""
return self.encoding.decode(tokens)

View File

@ -0,0 +1,53 @@
# Copyright (c) 2024 Microsoft Corporation.
# Licensed under the MIT License
"""Tokenizer Abstract Base Class."""
from abc import ABC, abstractmethod
class Tokenizer(ABC):
"""Tokenizer Abstract Base Class."""
@abstractmethod
def encode(self, text: str) -> list[int]:
"""Encode the given text into a list of tokens.
Args
----
text (str): The input text to encode.
Returns
-------
list[int]: A list of tokens representing the encoded text.
"""
msg = "The encode method must be implemented by subclasses."
raise NotImplementedError(msg)
@abstractmethod
def decode(self, tokens: list[int]) -> str:
"""Decode a list of tokens back into a string.
Args
----
tokens (list[int]): A list of tokens to decode.
Returns
-------
str: The decoded string from the list of tokens.
"""
msg = "The decode method must be implemented by subclasses."
raise NotImplementedError(msg)
def num_tokens(self, text: str) -> int:
"""Return the number of tokens in the given text.
Args
----
text (str): The input text to analyze.
Returns
-------
int: The number of tokens in the input text.
"""
return len(self.encode(text))
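Because num_tokens is implemented in terms of encode, a custom tokenizer only needs to supply encode and decode. A hypothetical whitespace tokenizer, handy for tests, might look like this (not part of the diff):

from graphrag.tokenizer.tokenizer import Tokenizer

class WhitespaceTokenizer(Tokenizer):
    """Toy tokenizer that maps each whitespace-delimited word to an integer id."""

    def __init__(self) -> None:
        self._vocab: dict[str, int] = {}
        self._reverse: dict[int, str] = {}

    def encode(self, text: str) -> list[int]:
        ids: list[int] = []
        for word in text.split():
            if word not in self._vocab:
                idx = len(self._vocab)
                self._vocab[word] = idx
                self._reverse[idx] = word
            ids.append(self._vocab[word])
        return ids

    def decode(self, tokens: list[int]) -> str:
        return " ".join(self._reverse[t] for t in tokens)

# num_tokens comes for free from the base class:
# WhitespaceTokenizer().num_tokens("a b c") == 3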

View File

@ -43,7 +43,7 @@ dependencies = [
"json-repair>=0.30.3",
"openai>=1.68.0",
"nltk==3.9.1",
"tiktoken>=0.9.0",
"tiktoken>=0.11.0",
# Data-Science
"numpy>=1.25.2",
"graspologic>=3.4.1",
@ -66,6 +66,7 @@ dependencies = [
"tqdm>=4.67.1",
"textblob>=0.18.0.post0",
"spacy>=3.8.4",
"litellm>=1.77.1",
]
[project.optional-dependencies]

View File

@ -9,7 +9,7 @@ import tiktoken
from graphrag.index.text_splitting.text_splitting import (
NoopTextSplitter,
Tokenizer,
TokenChunkerOptions,
TokenTextSplitter,
split_multiple_texts_on_tokens,
split_single_text_on_tokens,
@ -64,7 +64,7 @@ def test_split_text_large_input(mock_split):
@mock.patch("graphrag.index.text_splitting.text_splitting.split_single_text_on_tokens")
@mock.patch("graphrag.index.text_splitting.text_splitting.Tokenizer")
@mock.patch("graphrag.index.text_splitting.text_splitting.TokenChunkerOptions")
def test_token_text_splitter(mock_tokenizer, mock_split_text):
text = "chunk1 chunk2 chunk3"
expected_chunks = ["chunk1", "chunk2", "chunk3"]
@ -80,42 +80,10 @@ def test_token_text_splitter(mock_tokenizer, mock_split_text):
mock_split_text.assert_called_once_with(text=text, tokenizer=mocked_tokenizer)
def test_encode_basic():
splitter = TokenTextSplitter()
result = splitter.encode("abc def")
assert result == [13997, 711], "Encoding failed to return expected tokens"
def test_num_tokens_empty_input():
splitter = TokenTextSplitter()
result = splitter.num_tokens("")
assert result == 0, "Token count for empty input should be 0"
def test_model_name():
splitter = TokenTextSplitter(model_name="gpt-4o")
result = splitter.encode("abc def")
assert result == [26682, 1056], "Encoding failed to return expected tokens"
@mock.patch("tiktoken.encoding_for_model", side_effect=KeyError)
@mock.patch("tiktoken.get_encoding")
def test_model_name_exception(mock_get_encoding, mock_encoding_for_model):
mock_get_encoding.return_value = mock.MagicMock()
TokenTextSplitter(model_name="mock_model", encoding_name="mock_encoding")
mock_get_encoding.assert_called_once_with("mock_encoding")
mock_encoding_for_model.assert_called_once_with("mock_model")
def test_split_single_text_on_tokens():
text = "This is a test text, meaning to be taken seriously by this test only."
mocked_tokenizer = MockTokenizer()
tokenizer = Tokenizer(
tokenizer = TokenChunkerOptions(
chunk_overlap=5,
tokens_per_chunk=10,
decode=mocked_tokenizer.decode,
@ -150,7 +118,7 @@ def test_split_multiple_texts_on_tokens():
mocked_tokenizer = MockTokenizer()
mock_tick = MagicMock()
tokenizer = Tokenizer(
tokenizer = TokenChunkerOptions(
chunk_overlap=5,
tokens_per_chunk=10,
decode=mocked_tokenizer.decode,
@ -173,7 +141,7 @@ def test_split_single_text_on_tokens_no_overlap():
def decode(tokens: list[int]) -> str:
return enc.decode(tokens)
tokenizer = Tokenizer(
tokenizer = TokenChunkerOptions(
chunk_overlap=1,
tokens_per_chunk=2,
decode=decode,
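For context, the renamed TokenChunkerOptions pairs naturally with the new Tokenizer interface. A hedged sketch of wiring them together; chunk_overlap, tokens_per_chunk, and decode are taken from the tests above, the encode field and the numeric values are assumptions:

from graphrag.index.text_splitting.text_splitting import (
    TokenChunkerOptions,
    split_single_text_on_tokens,
)
from graphrag.tokenizer.get_tokenizer import get_tokenizer

tokenizer = get_tokenizer()
options = TokenChunkerOptions(
    chunk_overlap=5,       # illustrative values
    tokens_per_chunk=10,
    encode=tokenizer.encode,   # encode field assumed by analogy with decode
    decode=tokenizer.decode,
)
chunks = split_single_text_on_tokens(text="some text to chunk", tokenizer=options)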

View File

@ -0,0 +1,2 @@
# Copyright (c) 2024 Microsoft Corporation.
# Licensed under the MIT License

View File

@ -0,0 +1,368 @@
# Copyright (c) 2024 Microsoft Corporation.
# Licensed under the MIT License
"""Test LiteLLM Rate Limiter."""
import threading
import time
from math import ceil
from queue import Queue
import pytest
from graphrag.language_model.providers.litellm.services.rate_limiter.rate_limiter import (
RateLimiter,
)
from graphrag.language_model.providers.litellm.services.rate_limiter.rate_limiter_factory import (
RateLimiterFactory,
)
from tests.unit.litellm_services.utils import (
assert_max_num_values_per_period,
assert_stagger,
bin_time_intervals,
)
rate_limiter_factory = RateLimiterFactory()
_period_in_seconds = 1
_rpm = 4
_tpm = 75
_tokens_per_request = 25
_stagger = _period_in_seconds / _rpm
_num_requests = 10
def test_binning():
"""Test binning timings into 1-second intervals."""
values = [0.1, 0.2, 0.3, 0.4, 1.1, 1.2, 1.3, 1.4, 5.1]
binned_values = bin_time_intervals(values, 1)
assert binned_values == [
[0.1, 0.2, 0.3, 0.4],
[1.1, 1.2, 1.3, 1.4],
[],
[],
[],
[5.1],
]
def test_rate_limiter_validation():
"""Test rate limiter creation with both valid and invalid parameters."""
# Valid parameters
rate_limiter = rate_limiter_factory.create(
strategy="static", rpm=60, tpm=10000, period_in_seconds=60
)
assert rate_limiter is not None
# Invalid strategy
with pytest.raises(
ValueError,
match=r"Strategy 'invalid_strategy' is not registered.",
):
rate_limiter_factory.create(strategy="invalid_strategy", rpm=60, tpm=10000)
# Both rpm and tpm are None
with pytest.raises(
ValueError,
match=r"Both TPM and RPM cannot be None \(disabled\), one or both must be set to a positive integer.",
):
rate_limiter_factory.create(strategy="static")
# Invalid rpm
with pytest.raises(
ValueError,
match=r"RPM and TPM must be either None \(disabled\) or positive integers.",
):
rate_limiter_factory.create(strategy="static", rpm=-10)
# Invalid tpm
with pytest.raises(
ValueError,
match=r"RPM and TPM must be either None \(disabled\) or positive integers.",
):
rate_limiter_factory.create(strategy="static", tpm=-10)
# Invalid period_in_seconds
with pytest.raises(
ValueError, match=r"Period in seconds must be a positive integer."
):
rate_limiter_factory.create(strategy="static", rpm=10, period_in_seconds=-10)
def test_rpm():
"""Test that the rate limiter enforces RPM limits."""
rate_limiter = rate_limiter_factory.create(
strategy="static", rpm=_rpm, period_in_seconds=_period_in_seconds
)
time_values: list[float] = []
start_time = time.time()
for _ in range(_num_requests):
with rate_limiter.acquire(token_count=_tokens_per_request):
time_values.append(time.time() - start_time)
assert len(time_values) == _num_requests
binned_time_values = bin_time_intervals(time_values, _period_in_seconds)
"""
With _num_requests = 10 and _rpm = 4, we expect the requests to be
distributed across ceil(10/4) = 3 bins:
with a stagger of 1/4 = 0.25 seconds between requests.
"""
expected_num_bins = ceil(_num_requests / _rpm)
assert len(binned_time_values) == expected_num_bins
assert_max_num_values_per_period(binned_time_values, _rpm)
assert_stagger(time_values, _stagger)
def test_tpm():
"""Test that the rate limiter enforces TPM limits."""
rate_limiter = rate_limiter_factory.create(
strategy="static", tpm=_tpm, period_in_seconds=_period_in_seconds
)
time_values: list[float] = []
start_time = time.time()
for _ in range(_num_requests):
with rate_limiter.acquire(token_count=_tokens_per_request):
time_values.append(time.time() - start_time)
assert len(time_values) == _num_requests
binned_time_values = bin_time_intervals(time_values, _period_in_seconds)
"""
With _num_requests = 10, _tpm = 75 and _tokens_per_request = 25, we expect the requests to be
distributed across ceil((10 * 25) / 75) = 4 bins
and max requests per bin = (75 / 25) = 3 requests per bin.
"""
expected_num_bins = ceil((_num_requests * _tokens_per_request) / _tpm)
assert len(binned_time_values) == expected_num_bins
max_num_of_requests_per_bin = _tpm // _tokens_per_request
assert_max_num_values_per_period(binned_time_values, max_num_of_requests_per_bin)
def test_token_in_request_exceeds_tpm():
"""Test that the rate limiter allows for requests that use more tokens than the TPM.
A rate limiter could be configured with a tpm of 1000 but a request may use 2000 tokens,
greater than the tpm limit but still below the context window limit of the underlying model.
In this case, the request should still be allowed to proceed but may take up its own rate limit bin.
"""
rate_limiter = rate_limiter_factory.create(
strategy="static", tpm=_tpm, period_in_seconds=_period_in_seconds
)
time_values: list[float] = []
start_time = time.time()
for _ in range(2):
with rate_limiter.acquire(token_count=_tpm * 2):
time_values.append(time.time() - start_time)
assert len(time_values) == 2
binned_time_values = bin_time_intervals(time_values, _period_in_seconds)
"""
Since each request exceeds the tpm, we expect each request to still be fired off but to be in its own bin
"""
assert len(binned_time_values) == 2
assert_max_num_values_per_period(binned_time_values, 1)
def test_rpm_and_tpm_with_rpm_as_limiting_factor():
"""Test that the rate limiter enforces combined RPM and TPM limits when RPM is the limiting factor."""
rate_limiter = rate_limiter_factory.create(
strategy="static", rpm=_rpm, tpm=_tpm, period_in_seconds=_period_in_seconds
)
time_values: list[float] = []
start_time = time.time()
for _ in range(_num_requests):
# Use 0 tokens per request to simulate RPM as the limiting factor
with rate_limiter.acquire(token_count=0):
time_values.append(time.time() - start_time)
assert len(time_values) == _num_requests
binned_time_values = bin_time_intervals(time_values, _period_in_seconds)
"""
With _num_requests = 10 and _rpm = 4, we expect the requests to be
distributed across ceil(10/4) = 3 bins:
with a stagger of 1/4 = 0.25 seconds between requests.
"""
expected_num_bins = ceil(_num_requests / _rpm)
assert len(binned_time_values) == expected_num_bins
assert_max_num_values_per_period(binned_time_values, _rpm)
assert_stagger(time_values, _stagger)
def test_rpm_and_tpm_with_tpm_as_limiting_factor():
"""Test that the rate limiter enforces combined RPM and TPM limits when TPM is the limiting factor."""
rate_limiter = rate_limiter_factory.create(
strategy="static", rpm=_rpm, tpm=_tpm, period_in_seconds=_period_in_seconds
)
time_values: list[float] = []
start_time = time.time()
for _ in range(_num_requests):
with rate_limiter.acquire(token_count=_tokens_per_request):
time_values.append(time.time() - start_time)
assert len(time_values) == _num_requests
binned_time_values = bin_time_intervals(time_values, _period_in_seconds)
"""
With _num_requests = 10, _tpm = 75 and _tokens_per_request = 25, we expect the requests to be
distributed across ceil((10 * 25) / 75) = 4 bins
and max requests per bin = (75 / 25) = 3 requests per bin.
"""
expected_num_bins = ceil((_num_requests * _tokens_per_request) / _tpm)
assert len(binned_time_values) == expected_num_bins
max_num_of_requests_per_bin = _tpm // _tokens_per_request
assert_max_num_values_per_period(binned_time_values, max_num_of_requests_per_bin)
assert_stagger(time_values, _stagger)
def _run_rate_limiter(
rate_limiter: RateLimiter,
# Acquire cost
input_queue: Queue[int | None],
# time value
output_queue: Queue[float | None],
):
while True:
token_count = input_queue.get()
if token_count is None:
break
with rate_limiter.acquire(token_count=token_count):
output_queue.put(time.time())
def test_rpm_threaded():
"""Test that the rate limiter enforces RPM limits in a threaded environment."""
rate_limiter = rate_limiter_factory.create(
strategy="static", rpm=_rpm, tpm=_tpm, period_in_seconds=_period_in_seconds
)
input_queue: Queue[int | None] = Queue()
output_queue: Queue[float | None] = Queue()
# Spin up threads for half the number of requests
threads = [
threading.Thread(
target=_run_rate_limiter,
args=(rate_limiter, input_queue, output_queue),
)
for _ in range(_num_requests // 2) # Create 5 threads
]
for thread in threads:
thread.start()
start_time = time.time()
for _ in range(_num_requests):
# Use 0 tokens per request to simulate RPM as the limiting factor
input_queue.put(0)
# Signal threads to stop
for _ in range(len(threads)):
input_queue.put(None)
for thread in threads:
thread.join()
output_queue.put(None) # Signal end of output
time_values = []
while True:
time_value = output_queue.get()
if time_value is None:
break
time_values.append(time_value - start_time)
time_values.sort()
assert len(time_values) == _num_requests
binned_time_values = bin_time_intervals(time_values, _period_in_seconds)
"""
With _num_requests = 10 and _rpm = 4, we expect the requests to be
distributed across ceil(10/4) = 3 bins:
with a stagger of 1/4 = 0.25 seconds between requests.
"""
expected_num_bins = ceil(_num_requests / _rpm)
assert len(binned_time_values) == expected_num_bins
assert_max_num_values_per_period(binned_time_values, _rpm)
assert_stagger(time_values, _stagger)
def test_tpm_threaded():
"""Test that the rate limiter enforces TPM limits in a threaded environment."""
rate_limiter = rate_limiter_factory.create(
strategy="static", rpm=_rpm, tpm=_tpm, period_in_seconds=_period_in_seconds
)
input_queue: Queue[int | None] = Queue()
output_queue: Queue[float | None] = Queue()
# Spin up threads for half the number of requests
threads = [
threading.Thread(
target=_run_rate_limiter,
args=(rate_limiter, input_queue, output_queue),
)
for _ in range(_num_requests // 2) # Create 5 threads
]
for thread in threads:
thread.start()
start_time = time.time()
for _ in range(_num_requests):
input_queue.put(_tokens_per_request)
# Signal threads to stop
for _ in range(len(threads)):
input_queue.put(None)
for thread in threads:
thread.join()
output_queue.put(None) # Signal end of output
time_values = []
while True:
time_value = output_queue.get()
if time_value is None:
break
time_values.append(time_value - start_time)
time_values.sort()
assert len(time_values) == _num_requests
binned_time_values = bin_time_intervals(time_values, _period_in_seconds)
"""
With _num_requests = 10, _tpm = 75 and _tokens_per_request = 25, we expect the requests to be
distributed across ceil((10 * 25) / 75) = 4 bins
and max requests per bin = (75 / 25) = 3 requests per bin.
"""
expected_num_bins = ceil((_num_requests * _tokens_per_request) / _tpm)
assert len(binned_time_values) == expected_num_bins
max_num_of_requests_per_bin = _tpm // _tokens_per_request
assert_max_num_values_per_period(binned_time_values, max_num_of_requests_per_bin)
assert_stagger(time_values, _stagger)
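Outside the tests, a minimal sketch of how the static rate limiter might be used directly (parameter values are illustrative; acquire is the same context manager exercised above):

from graphrag.language_model.providers.litellm.services.rate_limiter.rate_limiter_factory import (
    RateLimiterFactory,
)

limiter = RateLimiterFactory().create(
    strategy="static",
    rpm=60,         # requests per period; illustrative
    tpm=10000,      # tokens per period; illustrative
    period_in_seconds=60,
)

def send_request(estimated_tokens: int) -> None:
    # Block until the request fits within the RPM/TPM budget for the current period.
    with limiter.acquire(token_count=estimated_tokens):
        ...  # issue the model call here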

View File

@ -0,0 +1,152 @@
# Copyright (c) 2024 Microsoft Corporation.
# Licensed under the MIT License
"""Test LiteLLM Retries."""
import time
import pytest
from graphrag.language_model.providers.litellm.services.retry.retry_factory import (
RetryFactory,
)
retry_factory = RetryFactory()
@pytest.mark.parametrize(
("strategy", "max_attempts", "max_retry_wait", "expected_time"),
[
(
"native",
3, # 3 retries
0, # native retry does not adhere to max_retry_wait
0, # immediate retry, expect 0 seconds elapsed time
),
(
"exponential_backoff",
3, # 3 retries
0, # exponential retry does not adhere to max_retry_wait
14, # (2^1 + jitter) + (2^2 + jitter) + (2^3 + jitter) >= 2 + 4 + 8 = 14 seconds minimum total runtime
),
(
"random_wait",
3, # 3 retries
2, # random wait [0, 2] seconds
0, # unpredictable, don't know what the total runtime will be
),
(
"incremental_wait",
3, # 3 retries
3, # wait for a max of 3 seconds on a single retry.
6, # Wait 3/3 * 1 on first retry, 3/3 * 2 on second, 3/3 * 3 on third, 1 + 2 + 3 = 6 seconds total runtime.
),
],
)
def test_retries(
strategy: str, max_attempts: int, max_retry_wait: int, expected_time: float
) -> None:
"""
Test each retry strategy with a representative configuration.
Args
----
strategy: The retry strategy to use.
max_attempts: The maximum number of retry attempts.
max_retry_wait: The maximum wait time between retries.
expected_time: The minimum total elapsed time expected across all retries.
"""
retry_service = retry_factory.create(
strategy=strategy,
max_attempts=max_attempts,
max_retry_wait=max_retry_wait,
)
retries = 0
def mock_func():
nonlocal retries
retries += 1
msg = "Mock error for testing retries"
raise ValueError(msg)
start_time = time.time()
with pytest.raises(ValueError, match="Mock error for testing retries"):
retry_service.retry(func=mock_func)
elapsed_time = time.time() - start_time
# subtract 1 from retries because the first call is not a retry
assert retries - 1 == max_attempts, (
f"Expected {max_attempts} retries, got {retries - 1}"
)
assert elapsed_time >= expected_time, (
f"Expected elapsed time >= {expected_time}, got {elapsed_time}"
)
@pytest.mark.parametrize(
("strategy", "max_attempts", "max_retry_wait", "expected_time"),
[
(
"native",
3, # 3 retries
0, # native retry does not adhere to max_retry_wait
0, # immediate retry, expect 0 seconds elapsed time
),
(
"exponential_backoff",
3, # 3 retries
0, # exponential retry does not adhere to max_retry_wait
14, # (2^1 + jitter) + (2^2 + jitter) + (2^3 + jitter) >= 2 + 4 + 8 = 14 seconds minimum total runtime
),
(
"random_wait",
3, # 3 retries
2, # random wait [0, 2] seconds
0, # unpredictable, don't know what the total runtime will be
),
(
"incremental_wait",
3, # 3 retries
3, # wait for a max of 3 seconds on a single retry.
6, # Wait 3/3 * 1 on first retry, 3/3 * 2 on second, 3/3 * 3 on third, 1 + 2 + 3 = 6 seconds total runtime.
),
],
)
async def test_retries_async(
strategy: str, max_attempts: int, max_retry_wait: int, expected_time: float
) -> None:
"""
Test each retry strategy with a representative configuration (async variant).
Args
----
strategy: The retry strategy to use.
max_attempts: The maximum number of retry attempts.
max_retry_wait: The maximum wait time between retries.
expected_time: The minimum total elapsed time expected across all retries.
"""
retry_service = retry_factory.create(
strategy=strategy,
max_attempts=max_attempts,
max_retry_wait=max_retry_wait,
)
retries = 0
async def mock_func(): # noqa: RUF029
nonlocal retries
retries += 1
msg = "Mock error for testing retries"
raise ValueError(msg)
start_time = time.time()
with pytest.raises(ValueError, match="Mock error for testing retries"):
await retry_service.aretry(func=mock_func)
elapsed_time = time.time() - start_time
# subtract 1 from retries because the first call is not a retry
assert retries - 1 == max_attempts, (
f"Expected {max_attempts} retries, got {retries - 1}"
)
assert elapsed_time >= expected_time, (
f"Expected elapsed time >= {expected_time}, got {elapsed_time}"
)
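Similarly, a brief sketch of using the retry factory outside the tests, wrapping a function that succeeds after transient failures (strategy and limits are illustrative):

from graphrag.language_model.providers.litellm.services.retry.retry_factory import (
    RetryFactory,
)

retry_service = RetryFactory().create(
    strategy="exponential_backoff",  # or "native", "random_wait", "incremental_wait"
    max_attempts=3,
    max_retry_wait=10,
)

def make_flaky():
    attempts = 0

    def flaky_call() -> str:
        nonlocal attempts
        attempts += 1
        if attempts < 3:
            msg = "transient failure"
            raise RuntimeError(msg)
        return "ok"

    return flaky_call

# Assumes the result of func is returned once a call succeeds; retries stop after max_attempts.
result = retry_service.retry(func=make_flaky())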

View File

@ -0,0 +1,37 @@
# Copyright (c) 2024 Microsoft Corporation.
# Licensed under the MIT License
"""LiteLLM Test Utilities."""
def bin_time_intervals(
time_values: list[float], time_interval: int
) -> list[list[float]]:
"""Bin values."""
bins: list[list[float]] = []
bin_number = 0
for time_value in time_values:
upper_bound = (bin_number * time_interval) + time_interval
while time_value >= upper_bound:
bin_number += 1
upper_bound = (bin_number * time_interval) + time_interval
while len(bins) <= bin_number:
bins.append([])
bins[bin_number].append(time_value)
return bins
def assert_max_num_values_per_period(
periods: list[list[float]], max_values_per_period: int
):
"""Assert the number of values per period."""
for period in periods:
assert len(period) <= max_values_per_period
def assert_stagger(time_values: list[float], stagger: float):
"""Assert stagger."""
for i in range(1, len(time_values)):
assert time_values[i] - time_values[i - 1] >= stagger

View File

@ -0,0 +1,18 @@
# Copyright (c) 2024 Microsoft Corporation.
# Licensed under the MIT License
from graphrag.tokenizer.get_tokenizer import get_tokenizer
def test_encode_basic():
tokenizer = get_tokenizer()
result = tokenizer.encode("abc def")
assert result == [13997, 711], "Encoding failed to return expected tokens"
def test_num_tokens_empty_input():
tokenizer = get_tokenizer()
result = len(tokenizer.encode(""))
assert result == 0, "Token count for empty input should be 0"

uv.lock (generated, 3761 lines changed): diff suppressed because it is too large.