-
Notifications
You must be signed in to change notification settings - Fork 20
Expand file tree
/
Copy pathtest_retries.py
More file actions
111 lines (84 loc) · 3.25 KB
/
test_retries.py
File metadata and controls
111 lines (84 loc) · 3.25 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
"""Tests for retry logic covering all TransportError subclasses."""
import asyncio
from unittest.mock import MagicMock
import httpx
import pytest
from unstructured_client.utils.retries import (
BackoffStrategy,
PermanentError,
Retries,
RetryConfig,
retry,
retry_async,
)
def _make_retries(retry_connection_errors: bool) -> Retries:
    """Build the ``Retries`` object shared by the tests in this module.

    The backoff policy is fixed (initial interval 100, max interval 200,
    exponent 1.5, max elapsed time 5000 -- presumably milliseconds; confirm
    against ``BackoffStrategy``) and the status-code list is empty, so the
    only knob the tests vary is ``retry_connection_errors``.

    Args:
        retry_connection_errors: Forwarded unchanged to ``RetryConfig``.

    Returns:
        A ``Retries`` wrapping the "backoff" ``RetryConfig`` described above.
    """
    backoff_policy = BackoffStrategy(
        initial_interval=100,
        max_interval=200,
        exponent=1.5,
        max_elapsed_time=5000,
    )
    retry_config = RetryConfig(
        strategy="backoff",
        backoff=backoff_policy,
        retry_connection_errors=retry_connection_errors,
    )
    return Retries(config=retry_config, status_codes=[])
# httpx.TransportError subclasses that should be retried when
# retry_connection_errors is enabled. Each entry is an
# (exception class, constructor message) pair fed to pytest.mark.parametrize
# by the test classes in this module.
# NOTE(review): the list is not exhaustive -- httpx also defines other
# TransportError subclasses (e.g. WriteTimeout, PoolTimeout, CloseError);
# confirm whether those should be covered as well.
TRANSPORT_ERRORS = [
    (httpx.ConnectError, "Connection refused"),
    (httpx.RemoteProtocolError, "Server disconnected without sending a response."),
    (httpx.ReadError, ""),
    (httpx.WriteError, ""),
    (httpx.ConnectTimeout, "Timed out"),
    (httpx.ReadTimeout, "Timed out"),
]
class TestTransportErrorRetry:
    """All httpx.TransportError subclasses should be retried when retry_connection_errors=True."""

    @pytest.mark.parametrize("exc_class,msg", TRANSPORT_ERRORS)
    def test_transport_error_retried_when_enabled(self, exc_class, msg):
        """One transport failure followed by success yields exactly two calls."""
        retries_config = _make_retries(retry_connection_errors=True)

        mock_response = MagicMock(spec=httpx.Response)
        mock_response.status_code = 200

        call_count = 0

        def func():
            # Fail on the first attempt only; succeed on every later one.
            nonlocal call_count
            call_count += 1
            if call_count == 1:
                raise exc_class(msg)
            return mock_response

        result = retry(func, retries_config)

        assert result.status_code == 200
        assert call_count == 2

    @pytest.mark.parametrize("exc_class,msg", TRANSPORT_ERRORS)
    def test_transport_error_not_retried_when_disabled(self, exc_class, msg):
        """With retries disabled the error propagates after a single attempt."""
        retries_config = _make_retries(retry_connection_errors=False)

        call_count = 0

        def func():
            nonlocal call_count
            call_count += 1
            raise exc_class(msg)

        with pytest.raises(exc_class):
            retry(func, retries_config)

        # The exception alone does not prove "not retried": it would also
        # surface after retrying until the backoff budget is exhausted.
        assert call_count == 1
class TestTransportErrorRetryAsync:
    """Async: All httpx.TransportError subclasses should be retried."""

    @pytest.mark.parametrize("exc_class,msg", TRANSPORT_ERRORS)
    def test_transport_error_retried_async(self, exc_class, msg):
        """One transport failure followed by success yields exactly two calls."""
        retries_config = _make_retries(retry_connection_errors=True)

        mock_response = MagicMock(spec=httpx.Response)
        mock_response.status_code = 200

        call_count = 0

        async def func():
            # Fail on the first attempt only; succeed on every later one.
            nonlocal call_count
            call_count += 1
            if call_count == 1:
                raise exc_class(msg)
            return mock_response

        # Drive the coroutine with asyncio.run so the test itself stays sync.
        result = asyncio.run(retry_async(func, retries_config))

        assert result.status_code == 200
        assert call_count == 2

    @pytest.mark.parametrize("exc_class,msg", TRANSPORT_ERRORS)
    def test_transport_error_not_retried_async_when_disabled(self, exc_class, msg):
        """With retries disabled the error propagates after a single attempt."""
        retries_config = _make_retries(retry_connection_errors=False)

        call_count = 0

        async def func():
            nonlocal call_count
            call_count += 1
            raise exc_class(msg)

        with pytest.raises(exc_class):
            asyncio.run(retry_async(func, retries_config))

        # The exception alone does not prove "not retried": it would also
        # surface after retrying until the backoff budget is exhausted.
        assert call_count == 1