-
Notifications
You must be signed in to change notification settings - Fork 30
Expand file tree
/
Copy pathconftest.py
More file actions
324 lines (262 loc) · 10.5 KB
/
conftest.py
File metadata and controls
324 lines (262 loc) · 10.5 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
from typing import (
Any,
Awaitable,
Callable,
Literal,
Mapping,
Optional,
Sequence,
Tuple,
Union,
cast,
)
from unittest.mock import AsyncMock, MagicMock
import httpx
import pytest
from tests.utils.client_configuration import ClientConfiguration
from tests.utils.list_resource import list_data_to_dicts, list_response_of
from tests.utils.syncify import syncify
from tests.types.test_auto_pagination_function import TestAutoPaginationFunction
from workos.types.list_resource import WorkOSListResource
from workos.utils._base_http_client import DEFAULT_REQUEST_TIMEOUT
from workos.utils.http_client import AsyncHTTPClient, HTTPClient, SyncHTTPClient
from workos.utils.request_helper import DEFAULT_LIST_RESPONSE_LIMIT
from jwt import PyJWKClient
from unittest.mock import Mock, patch
from functools import wraps
def _get_test_client_setup(
    http_client_class_name: str,
) -> Tuple[Literal["async", "sync"], ClientConfiguration, HTTPClient]:
    """Build the HTTP client and client configuration used by a test module.

    Args:
        http_client_class_name: Either ``"AsyncHTTPClient"`` or
            ``"SyncHTTPClient"``; selects which client class is instantiated.

    Returns:
        A ``(setup_name, client_configuration, http_client)`` tuple, where
        ``setup_name`` is ``"async"`` or ``"sync"`` to match the client built.

    Raises:
        ValueError: If ``http_client_class_name`` is neither supported name.
    """
    base_url = "https://api.workos.test/"
    client_id = "client_b27needthisforssotemxo"

    # Reject unknown names up front, before anything is constructed.
    wants_async = http_client_class_name == "AsyncHTTPClient"
    if not (wants_async or http_client_class_name == "SyncHTTPClient"):
        raise ValueError(
            f"Invalid HTTP client for test module setup: {http_client_class_name}"
        )

    # Both client flavours are created with the same fixed test credentials.
    shared_client_kwargs = {
        "api_key": "sk_test",
        "base_url": base_url,
        "client_id": client_id,
        "version": "test",
    }
    if wants_async:
        setup_name = "async"
        http_client = AsyncHTTPClient(**shared_client_kwargs)
    else:
        setup_name = "sync"
        http_client = SyncHTTPClient(**shared_client_kwargs)

    client_configuration = ClientConfiguration(
        base_url=base_url,
        client_id=client_id,
        request_timeout=DEFAULT_REQUEST_TIMEOUT,
    )
    return setup_name, client_configuration, http_client
def pytest_configure(config) -> None:
    """Register the custom ``sync_and_async`` marker on the given config.

    Adds a single line to the ``markers`` ini value via
    ``config.addinivalue_line``.
    """
    marker_description = (
        "sync_and_async(): mark test to run both sync and async module versions"
    )
    config.addinivalue_line("markers", marker_description)
def pytest_generate_tests(metafunc: pytest.Metafunc):
    """Parametrize ``sync_and_async``-marked tests with module instances.

    For every ``sync_and_async`` marker on the test, each positional marker
    argument is treated as a module class. One instance per class is built
    with a matching test HTTP client and passed to the test as the
    ``module_instance`` argument (class-scoped), using the setup name
    returned by ``_get_test_client_setup`` as the test ID.

    Raises:
        ValueError: If the marker has no arguments, or one of them is None.
    """
    for marker in metafunc.definition.iter_markers(name="sync_and_async"):
        if marker.name != "sync_and_async":
            continue

        # Take in args as a list of module classes. For example:
        # @pytest.mark.sync_and_async(Events, AsyncEvents) -> [Events, AsyncEvents]
        module_classes = marker.args
        if not module_classes:
            raise ValueError(
                "sync_and_async marker requires argument representing list of modules."
            )

        arg_names = ["module_instance"]
        test_ids = []
        arg_values = []

        for module_class in module_classes:
            if module_class is None:
                raise ValueError(
                    f"Invalid module class for sync_and_async marker: {module_class}"
                )

            # The annotation on the class's `_http_client` attribute names the
            # HTTP client type; use it to pick the proper test HTTP client.
            http_client_annotation = module_class.__annotations__["_http_client"]
            setup_name, client_configuration, http_client = _get_test_client_setup(
                http_client_annotation.__name__
            )

            init_kwargs = {"http_client": http_client}
            # Only modules whose __init__ annotates `client_configuration`
            # receive one.
            if module_class.__init__.__annotations__.get("client_configuration"):
                init_kwargs["client_configuration"] = client_configuration

            module_instance = module_class(**init_kwargs)

            test_ids.append(setup_name)  # sync or async will be the test ID
            arg_values.append([module_instance])

        metafunc.parametrize(
            argnames=arg_names, argvalues=arg_values, ids=test_ids, scope="class"
        )
@pytest.fixture
def sync_http_client_for_test():
    """Provide the HTTP client from the ``"SyncHTTPClient"`` test setup."""
    client_setup = _get_test_client_setup("SyncHTTPClient")
    # The setup tuple is (setup_name, client_configuration, http_client).
    return client_setup[2]
@pytest.fixture
def sync_client_configuration_and_http_client_for_test():
    """Provide ``(client_configuration, http_client)`` from the sync test setup."""
    client_setup = _get_test_client_setup("SyncHTTPClient")
    # The setup tuple is (setup_name, client_configuration, http_client);
    # the setup name is not needed here.
    return client_setup[1], client_setup[2]
@pytest.fixture
def mock_http_client_with_response(monkeypatch):
    """Fixture returning a helper that stubs a client's ``request`` method.

    The returned ``inner`` helper replaces ``http_client._client.request``
    (via ``monkeypatch``) with a mock that returns an ``httpx.Response`` built
    from the given status code, headers and JSON body.
    """

    def inner(
        http_client: HTTPClient,
        response_dict: Optional[dict] = None,
        status_code: int = 200,
        headers: Optional[Mapping[str, str]] = None,
    ):
        canned_response = httpx.Response(
            status_code=status_code, headers=headers, json=response_dict
        )
        # Async clients get an AsyncMock in place of `request`; all other
        # clients get a MagicMock.
        if isinstance(http_client, AsyncHTTPClient):
            request_mock = AsyncMock(return_value=canned_response)
        else:
            request_mock = MagicMock(return_value=canned_response)

        monkeypatch.setattr(http_client._client, "request", request_mock)

    return inner
@pytest.fixture
def capture_and_mock_http_client_request(monkeypatch):
    """Fixture returning a helper that stubs ``request`` and records its kwargs.

    The returned ``inner`` helper replaces ``http_client._client.request``
    with a mock and returns a dict. Every call to the mocked ``request``
    merges its keyword arguments into that dict and returns a fresh
    ``httpx.Response`` built from the given status code, headers and JSON
    body. Positional arguments are accepted but not recorded.
    """

    def inner(
        http_client: HTTPClient,
        response_dict: Optional[dict] = None,
        status_code: int = 200,
        headers: Optional[Mapping[str, str]] = None,
    ):
        captured_kwargs = {}

        def record_and_respond(*args, **kwargs):
            captured_kwargs.update(kwargs)
            return httpx.Response(
                status_code=status_code,
                headers=headers,
                json=response_dict,
            )

        # Async clients get an AsyncMock in place of `request`; all other
        # clients get a MagicMock.
        if isinstance(http_client, AsyncHTTPClient):
            request_mock = AsyncMock(side_effect=record_and_respond)
        else:
            request_mock = MagicMock(side_effect=record_and_respond)

        monkeypatch.setattr(http_client._client, "request", request_mock)
        return captured_kwargs

    return inner
@pytest.fixture
def capture_and_mock_pagination_request_for_http_client(monkeypatch):
    """Fixture returning a helper that mocks paginated list requests.

    Mocking pagination correctly requires indexing into a list of data and
    setting the before/after metadata in each response. The returned ``inner``
    helper replaces ``http_client._client.request`` with a mock that serves
    ``data_list`` one page at a time, driven by the ``after`` and ``limit``
    request params, and returns a dict into which the keyword arguments of
    every mocked request are merged.
    """

    def inner(
        http_client: HTTPClient,
        data_list: list,
        status_code: int = 200,
        headers: Optional[Mapping[str, str]] = None,
    ):
        captured_kwargs = {}

        # For convenient index lookup, store the list of object IDs.
        data_ids = [item["id"] for item in data_list]

        def serve_page(*args, **kwargs):
            captured_kwargs.update(kwargs)

            params = kwargs.get("params") or {}
            request_after = params.get("after", None)
            limit = params.get("limit", 10)

            # First page starts at 0; a subsequent page starts at the first
            # item _after_ the one named by the cursor.
            if request_after is None:
                start = 0
            else:
                start = data_ids.index(request_after) + 1
            page = data_list[start : start + limit]

            # A full, non-empty page points `after` at its last item; a short
            # or empty page means there is no more data.
            page_is_full = len(page) >= limit and len(page) != 0
            after = page[-1]["id"] if page_is_full else None

            return httpx.Response(
                status_code=status_code,
                headers=headers,
                json=list_response_of(data=page, before=request_after, after=after),
            )

        # Async clients get an AsyncMock in place of `request`; all other
        # clients get a MagicMock.
        if isinstance(http_client, AsyncHTTPClient):
            request_mock = AsyncMock(side_effect=serve_page)
        else:
            request_mock = MagicMock(side_effect=serve_page)

        monkeypatch.setattr(http_client._client, "request", request_mock)
        return captured_kwargs

    return inner
@pytest.fixture
def test_auto_pagination(
    capture_and_mock_pagination_request_for_http_client,
) -> TestAutoPaginationFunction:
    """Fixture returning a helper that checks a list function's auto-pagination.

    The returned ``inner`` helper mocks paginated responses for
    ``http_client``, calls ``list_function`` with ``list_function_params``,
    iterates the returned resource to exhaustion (awaiting and using
    ``async for`` when the client is an ``AsyncHTTPClient``), and asserts that
    the collected items match ``expected_all_page_data`` and that the most
    recently captured request kwargs look as expected.
    """

    def _collect_sync(
        list_function: Callable[[], WorkOSListResource],
        list_function_params: Optional[Mapping[str, Any]] = None,
    ) -> Sequence[Any]:
        call_kwargs = list_function_params or {}
        resource = list_function(**call_kwargs)
        return [item for item in resource]

    async def _collect_async(
        list_function: Callable[[], Awaitable[WorkOSListResource]],
        list_function_params: Optional[Mapping[str, Any]] = None,
    ) -> Sequence[Any]:
        call_kwargs = list_function_params or {}
        resource = await list_function(**call_kwargs)
        return [item async for item in resource]

    def inner(
        http_client: HTTPClient,
        list_function: Union[
            Callable[[], WorkOSListResource],
            Callable[[], Awaitable[WorkOSListResource]],
        ],
        expected_all_page_data: dict,
        list_function_params: Optional[Mapping[str, Any]] = None,
        url_path_keys: Optional[Sequence[str]] = None,
    ) -> None:
        # NOTE(review): `expected_all_page_data` is annotated `dict` but is
        # forwarded as `data_list` and compared against a list of results --
        # presumably a list of dicts; confirm before changing the annotation.
        request_kwargs = capture_and_mock_pagination_request_for_http_client(
            http_client=http_client,
            data_list=expected_all_page_data,
            status_code=200,
        )

        if isinstance(http_client, AsyncHTTPClient):
            pending = _collect_async(
                cast(Callable[[], Awaitable[WorkOSListResource]], list_function),
                list_function_params,
            )
            all_results = syncify(pending)
        else:
            all_results = _collect_sync(
                cast(Callable[[], WorkOSListResource], list_function),
                list_function_params,
            )

        assert len(list(all_results)) == len(expected_all_page_data)
        assert (list_data_to_dicts(all_results)) == expected_all_page_data
        assert request_kwargs["method"] == "get"

        # Validate parameters
        sent_params = request_kwargs["params"]
        assert "after" in sent_params
        assert sent_params["limit"] == DEFAULT_LIST_RESPONSE_LIMIT
        assert sent_params["order"] == "desc"

        # NOTE(review): when `url_path_keys` is None, none of the caller's
        # params are compared against the request. This may have been meant to
        # compare all of them -- confirm before changing, since that would
        # tighten the check for every existing caller.
        if url_path_keys is not None:
            for name, value in (list_function_params or {}).items():
                if name not in url_path_keys:
                    assert sent_params[name] == value

    return inner
def with_jwks_mock(func):
    """Decorator that runs ``func`` with ``workos.session.PyJWKClient`` patched.

    While ``func`` runs, ``workos.session.PyJWKClient`` is replaced so that
    calling it returns a mock client whose ``get_signing_key_from_jwt``
    returns a signing key mock. That mock's ``key`` is taken from
    ``kwargs["session_constants"]["PUBLIC_KEY"]``, so the wrapped function
    must be called with a ``session_constants`` keyword argument.
    """

    @wraps(func)
    def wrapper(*args, **kwargs):
        # Signing key whose `key` attribute is the test public key.
        public_key = kwargs["session_constants"]["PUBLIC_KEY"]
        signing_key = Mock(key=public_key)

        # Mock JWKS client that hands out that signing key.
        jwks_client = Mock(spec=PyJWKClient)
        jwks_client.get_signing_key_from_jwt.return_value = signing_key

        # Apply the mock only for the duration of the wrapped call.
        with patch("workos.session.PyJWKClient", return_value=jwks_client):
            return func(*args, **kwargs)

    return wrapper