graphrag/graphrag/language_model/providers/litellm/services/retry/exponential_retry.py
Derek Worthen 2b70e4a4f3
Tokenizer (#2051)
* Add LiteLLM chat and embedding model providers.

* Fix code review findings.

* Add litellm.

* Fix formatting.

* Update dictionary.

* Update litellm.

* Fix embedding.

* Remove manual use of tiktoken and replace with
Tokenizer interface. Adds support for encoding
and decoding the models supported by litellm.

* Update litellm.

* Configure litellm to drop unsupported params.

* Cleanup semversioner release notes.

* Add num_tokens util to Tokenizer interface.

* Update litellm service factories.

* Cleanup litellm chat/embedding model argument assignment.

* Update chat and embedding type field for litellm use and future migration away from fnllm.

* Flatten litellm service organization.

* Update litellm.

* Update litellm factory validation.

* Flatten litellm rate limit service organization.

* Update rate limiter - disable with None/null instead of 0.

* Fix usage of get_tokenizer.

* Update litellm service registrations.

* Add jitter to exponential retry.

* Update validation.

* Update validation.

* Add litellm request logging layer.

* Update cache key.

* Update defaults.

---------

Co-authored-by: Alonso Guevara <alonsog@microsoft.com>
2025-09-22 13:55:14 -06:00

84 lines
2.9 KiB
Python

# Copyright (c) 2025 Microsoft Corporation.
# Licensed under the MIT License
"""LiteLLM Exponential Retry Service."""
import asyncio
import logging
import random
import time
from collections.abc import Awaitable, Callable
from typing import Any
from graphrag.language_model.providers.litellm.services.retry.retry import Retry
logger = logging.getLogger(__name__)
class ExponentialRetry(Retry):
    """LiteLLM Exponential Retry Service.

    Calls a function and, whenever it raises an ``Exception``, waits and calls
    it again. The first wait is ``base_delay`` seconds and every further
    failure multiplies the wait by ``base_delay`` again. When ``jitter`` is
    enabled, a random extra of up to one second is added to each wait.
    """

    def __init__(
        self,
        *,
        max_attempts: int = 5,
        base_delay: float = 2.0,
        jitter: bool = True,
        **kwargs: Any,
    ):
        """Validate and store the retry settings.

        Args:
            max_attempts: Number of retries allowed after the initial call, so
                the wrapped function is invoked at most ``max_attempts + 1``
                times. Must be greater than 0.
            base_delay: Multiplier applied to the wait after each failure; it
                is also the length in seconds of the first wait. Must be
                greater than 1.0.
            jitter: Whether to add a random extra of up to one second to
                every wait.
            **kwargs: Accepted and ignored by this class.

        Raises:
            ValueError: If ``max_attempts`` or ``base_delay`` is out of range.
        """
        if max_attempts <= 0:
            msg = "max_attempts must be greater than 0."
            raise ValueError(msg)

        if base_delay <= 1.0:
            msg = "base_delay must be greater than 1.0."
            raise ValueError(msg)

        self._max_attempts = max_attempts
        self._base_delay = base_delay
        self._jitter = jitter

    def _retries_exhausted(self, retries: int, error: Exception) -> bool:
        """Report whether the retry budget is spent, logging the failure if so."""
        if retries < self._max_attempts:
            return False

        # exc_info is passed explicitly so the traceback is attached even
        # though this helper is not itself inside an ``except`` block.
        logger.exception(
            "ExponentialRetry: Max retries exceeded, retries=%s, max_retries=%s, exception=%s",
            retries,
            self._max_attempts,
            error,
            exc_info=error,
        )
        return True

    def _advance(
        self, retries: int, delay: float, error: Exception
    ) -> tuple[int, float]:
        """Count one more retry, grow the delay, and log the upcoming retry."""
        retries += 1
        delay *= self._base_delay
        logger.exception(
            "ExponentialRetry: Request failed, retrying, retries=%s, delay=%s, max_retries=%s, exception=%s",
            retries,
            delay,
            self._max_attempts,
            error,
            exc_info=error,
        )
        return retries, delay

    def _sleep_seconds(self, delay: float) -> float:
        """Return how long to wait: the delay plus up to 1s of optional jitter."""
        extra = random.uniform(0, 1) if self._jitter else 0.0  # noqa: S311
        return delay + extra

    def retry(self, func: Callable[..., Any], **kwargs: Any) -> Any:
        """Retry a synchronous function.

        Args:
            func: Callable to invoke; it is called as ``func(**kwargs)``.
            **kwargs: Keyword arguments forwarded to ``func`` on every call.

        Returns:
            Whatever ``func`` returns on its first successful call.

        Raises:
            Exception: The exception from the last call, re-raised unchanged
                once ``max_attempts`` retries have already been made.
        """
        retries = 0
        delay = 1.0  # Multiplied by base_delay before the first wait.
        while True:
            try:
                return func(**kwargs)
            except Exception as e:
                if self._retries_exhausted(retries, e):
                    raise
                retries, delay = self._advance(retries, delay, e)
                time.sleep(self._sleep_seconds(delay))

    async def aretry(
        self,
        func: Callable[..., Awaitable[Any]],
        **kwargs: Any,
    ) -> Any:
        """Retry an asynchronous function.

        Args:
            func: Callable returning an awaitable; it is awaited as
                ``await func(**kwargs)``.
            **kwargs: Keyword arguments forwarded to ``func`` on every call.

        Returns:
            Whatever awaiting ``func`` yields on its first successful call.

        Raises:
            Exception: The exception from the last call, re-raised unchanged
                once ``max_attempts`` retries have already been made.
        """
        retries = 0
        delay = 1.0  # Multiplied by base_delay before the first wait.
        while True:
            try:
                return await func(**kwargs)
            except Exception as e:
                if self._retries_exhausted(retries, e):
                    raise
                retries, delay = self._advance(retries, delay, e)
                # Non-blocking wait so other tasks keep running meanwhile.
                await asyncio.sleep(self._sleep_seconds(delay))