mirror of
https://github.com/microsoft/graphrag.git
synced 2026-01-14 09:07:20 +08:00
152 lines
4.7 KiB
Python
152 lines
4.7 KiB
Python
# Copyright (c) 2024 Microsoft Corporation.
|
|
# Licensed under the MIT License
|
|
|
|
"""Test LiteLLM Retries."""
|
|
|
|
import time
|
|
|
|
import pytest
|
|
from graphrag.language_model.providers.litellm.services.retry.retry_factory import (
|
|
RetryFactory,
|
|
)
|
|
|
|
# Single factory instance; both tests below build their retry services from it.
retry_factory: RetryFactory = RetryFactory()
|
|
|
|
|
|
@pytest.mark.parametrize(
    ("strategy", "max_retries", "max_retry_wait", "expected_time"),
    [
        (
            "native",
            3,  # 3 retries
            0,  # native retry does not adhere to max_retry_wait
            0,  # immediate retry, expect 0 seconds elapsed time
        ),
        (
            "exponential_backoff",
            3,  # 3 retries
            0,  # exponential retry does not adhere to max_retry_wait
            14,  # (2^1 + jitter) + (2^2 + jitter) + (2^3 + jitter) = 2 + 4 + 8 + 3*jitter = 14 seconds min total runtime
        ),
        (
            "random_wait",
            3,  # 3 retries
            2,  # random wait [0, 2] seconds
            0,  # unpredictable, don't know what the total runtime will be
        ),
        (
            "incremental_wait",
            3,  # 3 retries
            3,  # wait for a max of 3 seconds on a single retry.
            6,  # Wait 3/3 * 1 on first retry, 3/3 * 2 on second, 3/3 * 3 on third, 1 + 2 + 3 = 6 seconds total runtime.
        ),
    ],
)
def test_retries(
    strategy: str, max_retries: int, max_retry_wait: int, expected_time: float
) -> None:
    """
    Test various retry strategies with various configurations.

    A function that always raises ``ValueError`` is passed to the synchronous
    ``retry`` method. The test checks that the error eventually propagates,
    that the function was retried exactly ``max_retries`` times, and that at
    least ``expected_time`` seconds elapsed.

    Args
    ----
        strategy: The retry strategy to use.
        max_retries: The maximum number of retry attempts.
        max_retry_wait: The maximum wait time between retries.
        expected_time: Lower bound, in seconds, on the total elapsed time.
    """
    retry_service = retry_factory.create(
        strategy=strategy,
        init_args={
            "max_retries": max_retries,
            "max_retry_wait": max_retry_wait,
        },
    )

    retries = 0

    def mock_func():
        nonlocal retries
        retries += 1
        msg = "Mock error for testing retries"
        raise ValueError(msg)

    # perf_counter is monotonic, so a system clock adjustment during the
    # (potentially multi-second) retry sequence cannot skew the measurement.
    start_time = time.perf_counter()
    with pytest.raises(ValueError, match="Mock error for testing retries"):
        retry_service.retry(func=mock_func)
    elapsed_time = time.perf_counter() - start_time

    # subtract 1 from retries because the first call is not a retry
    actual_retries = retries - 1
    assert actual_retries == max_retries, (
        f"Expected {max_retries} retries, got {actual_retries}"
    )
    assert elapsed_time >= expected_time, (
        f"Expected elapsed time >= {expected_time}, got {elapsed_time}"
    )
|
|
|
|
|
|
@pytest.mark.parametrize(
    ("strategy", "max_retries", "max_retry_wait", "expected_time"),
    [
        (
            "native",
            3,  # 3 retries
            0,  # native retry does not adhere to max_retry_wait
            0,  # immediate retry, expect 0 seconds elapsed time
        ),
        (
            "exponential_backoff",
            3,  # 3 retries
            0,  # exponential retry does not adhere to max_retry_wait
            14,  # (2^1 + jitter) + (2^2 + jitter) + (2^3 + jitter) = 2 + 4 + 8 + 3*jitter = 14 seconds min total runtime
        ),
        (
            "random_wait",
            3,  # 3 retries
            2,  # random wait [0, 2] seconds
            0,  # unpredictable, don't know what the total runtime will be
        ),
        (
            "incremental_wait",
            3,  # 3 retries
            3,  # wait for a max of 3 seconds on a single retry.
            6,  # Wait 3/3 * 1 on first retry, 3/3 * 2 on second, 3/3 * 3 on third, 1 + 2 + 3 = 6 seconds total runtime.
        ),
    ],
)
# NOTE(review): no explicit asyncio marker here; this presumably relies on
# pytest-asyncio running in auto mode — confirm in the project's pytest config.
async def test_retries_async(
    strategy: str, max_retries: int, max_retry_wait: int, expected_time: float
) -> None:
    """
    Test various retry strategies with various configurations (async path).

    A coroutine function that always raises ``ValueError`` is passed to the
    asynchronous ``aretry`` method. The test checks that the error eventually
    propagates, that the function was retried exactly ``max_retries`` times,
    and that at least ``expected_time`` seconds elapsed.

    Args
    ----
        strategy: The retry strategy to use.
        max_retries: The maximum number of retry attempts.
        max_retry_wait: The maximum wait time between retries.
        expected_time: Lower bound, in seconds, on the total elapsed time.
    """
    retry_service = retry_factory.create(
        strategy=strategy,
        init_args={
            "max_retries": max_retries,
            "max_retry_wait": max_retry_wait,
        },
    )

    retries = 0

    async def mock_func():  # noqa: RUF029
        nonlocal retries
        retries += 1
        msg = "Mock error for testing retries"
        raise ValueError(msg)

    # perf_counter is monotonic, so a system clock adjustment during the
    # (potentially multi-second) retry sequence cannot skew the measurement.
    start_time = time.perf_counter()
    with pytest.raises(ValueError, match="Mock error for testing retries"):
        await retry_service.aretry(func=mock_func)
    elapsed_time = time.perf_counter() - start_time

    # subtract 1 from retries because the first call is not a retry
    actual_retries = retries - 1
    assert actual_retries == max_retries, (
        f"Expected {max_retries} retries, got {actual_retries}"
    )
    assert elapsed_time >= expected_time, (
        f"Expected elapsed time >= {expected_time}, got {elapsed_time}"
    )
|