graphrag/tests/integration/language_model/test_retries.py
Nathan Evans 1bb9fa8e13
Some checks are pending
Python CI / python-ci (ubuntu-latest, 3.10) (push) Waiting to run
Python CI / python-ci (ubuntu-latest, 3.11) (push) Waiting to run
Python CI / python-ci (windows-latest, 3.10) (push) Waiting to run
Python CI / python-ci (windows-latest, 3.11) (push) Waiting to run
Python Integration Tests / python-ci (ubuntu-latest, 3.10) (push) Waiting to run
Python Integration Tests / python-ci (windows-latest, 3.10) (push) Waiting to run
Python Notebook Tests / python-ci (ubuntu-latest, 3.10) (push) Waiting to run
Python Notebook Tests / python-ci (windows-latest, 3.10) (push) Waiting to run
Python Smoke Tests / python-ci (ubuntu-latest, 3.10) (push) Waiting to run
Python Smoke Tests / python-ci (windows-latest, 3.10) (push) Waiting to run
Unified factory (#2105)
* Simplify Factory interface

* Migrate CacheFactory to standard base class

* Migrate LoggerFactory to standard base class

* Migrate StorageFactory to standard base class

* Migrate VectorStoreFactory to standard base class

* Update vector store example notebook

* Delete notebook outputs

* Move default providers into factories

* Move retry/limit tests into integ

* Split language model factories

* Set smoke test tpm/rpm

* Fix factory integ tests

* Add method to smoke test, switch text to 'fast'

* Fix text smoke config for fast workflow

* Add new workflows to text smoke test

* Convert input readers to a proper factory

* Remove covariates from fast smoke test

* Update docs for input factory

* Bump smoke runtime

* Even longer runtime

* min-csv timeout

* Remove unnecessary lambdas
2025-10-20 12:05:27 -07:00

153 lines
4.7 KiB
Python

# Copyright (c) 2024 Microsoft Corporation.
# Licensed under the MIT License
"""Test LiteLLM Retries."""
import time
import pytest
from graphrag.language_model.providers.litellm.services.retry.retry_factory import (
RetryFactory,
)
# Module-level factory shared by the tests below; each test calls
# `retry_factory.create(strategy=..., init_args=...)` to build the retry
# service for the strategy under test.
retry_factory = RetryFactory()
@pytest.mark.parametrize(
    ("strategy", "max_retries", "max_retry_wait", "expected_time"),
    [
        (
            "native",
            3,  # 3 retries
            0,  # native retry does not adhere to max_retry_wait
            0,  # immediate retry, expect 0 seconds elapsed time
        ),
        (
            "exponential_backoff",
            3,  # 3 retries
            0,  # exponential retry does not adhere to max_retry_wait
            # (2^1 + jitter) + (2^2 + jitter) + (2^3 + jitter)
            # = 2 + 4 + 8 + 3*jitter = 14 seconds min total runtime
            14,
        ),
        (
            "random_wait",
            3,  # 3 retries
            2,  # random wait [0, 2] seconds
            0,  # unpredictable, don't know what the total runtime will be
        ),
        (
            "incremental_wait",
            3,  # 3 retries
            3,  # wait for a max of 3 seconds on a single retry.
            # Wait 3/3 * 1 on first retry, 3/3 * 2 on second, 3/3 * 3 on third,
            # 1 + 2 + 3 = 6 seconds total runtime.
            6,
        ),
    ],
)
def test_retries(
    strategy: str, max_retries: int, max_retry_wait: int, expected_time: float
) -> None:
    """
    Test various retry strategies with various configurations.

    A function that always raises is passed to the retry service's synchronous
    ``retry`` method. The test then checks how many times the function was
    re-invoked and that the whole run took at least ``expected_time`` seconds.

    Args
    ----
        strategy: The retry strategy to use.
        max_retries: The maximum number of retry attempts.
        max_retry_wait: The maximum wait time between retries.
        expected_time: Lower bound, in seconds, on the total elapsed time.
    """
    retry_service = retry_factory.create(
        strategy=strategy,
        init_args={
            "max_retries": max_retries,
            "max_retry_wait": max_retry_wait,
        },
    )

    # Counts every invocation of mock_func, including the initial attempt.
    calls = 0

    def mock_func():
        nonlocal calls
        calls += 1
        msg = "Mock error for testing retries"
        raise ValueError(msg)

    # Use a monotonic clock so system clock adjustments (e.g. NTP) cannot
    # shrink the measured duration and make the lower-bound check flaky.
    start_time = time.perf_counter()
    with pytest.raises(ValueError, match="Mock error for testing retries"):
        retry_service.retry(func=mock_func)
    elapsed_time = time.perf_counter() - start_time

    # subtract 1 from the call count because the first call is not a retry
    actual_retries = calls - 1
    assert actual_retries == max_retries, (
        f"Expected {max_retries} retries, got {actual_retries}"
    )
    assert elapsed_time >= expected_time, (
        f"Expected elapsed time >= {expected_time}, got {elapsed_time}"
    )
@pytest.mark.parametrize(
    ("strategy", "max_retries", "max_retry_wait", "expected_time"),
    [
        (
            "native",
            3,  # 3 retries
            0,  # native retry does not adhere to max_retry_wait
            0,  # immediate retry, expect 0 seconds elapsed time
        ),
        (
            "exponential_backoff",
            3,  # 3 retries
            0,  # exponential retry does not adhere to max_retry_wait
            # (2^1 + jitter) + (2^2 + jitter) + (2^3 + jitter)
            # = 2 + 4 + 8 + 3*jitter = 14 seconds min total runtime
            14,
        ),
        (
            "random_wait",
            3,  # 3 retries
            2,  # random wait [0, 2] seconds
            0,  # unpredictable, don't know what the total runtime will be
        ),
        (
            "incremental_wait",
            3,  # 3 retries
            3,  # wait for a max of 3 seconds on a single retry.
            # Wait 3/3 * 1 on first retry, 3/3 * 2 on second, 3/3 * 3 on third,
            # 1 + 2 + 3 = 6 seconds total runtime.
            6,
        ),
    ],
)
async def test_retries_async(
    strategy: str, max_retries: int, max_retry_wait: int, expected_time: float
) -> None:
    """
    Test various retry strategies with various configurations (async path).

    A coroutine function that always raises is passed to the retry service's
    ``aretry`` method. The test then checks how many times the coroutine was
    re-invoked and that the whole run took at least ``expected_time`` seconds.

    Args
    ----
        strategy: The retry strategy to use.
        max_retries: The maximum number of retry attempts.
        max_retry_wait: The maximum wait time between retries.
        expected_time: Lower bound, in seconds, on the total elapsed time.
    """
    retry_service = retry_factory.create(
        strategy=strategy,
        init_args={
            "max_retries": max_retries,
            "max_retry_wait": max_retry_wait,
        },
    )

    # Counts every invocation of mock_func, including the initial attempt.
    calls = 0

    async def mock_func():  # noqa: RUF029
        nonlocal calls
        calls += 1
        msg = "Mock error for testing retries"
        raise ValueError(msg)

    # Use a monotonic clock so system clock adjustments (e.g. NTP) cannot
    # shrink the measured duration and make the lower-bound check flaky.
    start_time = time.perf_counter()
    with pytest.raises(ValueError, match="Mock error for testing retries"):
        await retry_service.aretry(func=mock_func)
    elapsed_time = time.perf_counter() - start_time

    # subtract 1 from the call count because the first call is not a retry
    actual_retries = calls - 1
    assert actual_retries == max_retries, (
        f"Expected {max_retries} retries, got {actual_retries}"
    )
    assert elapsed_time >= expected_time, (
        f"Expected elapsed time >= {expected_time}, got {elapsed_time}"
    )